this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

session fixes

+108 -67
+108 -67
atproto/auth/oauth/session.go
··· 20 20 "github.com/google/go-querystring/query" 21 21 ) 22 22 23 - type RefreshCallback = func(ctx context.Context, data ClientSessionData) 23 + type PersistSessionCallback = func(ctx context.Context, data *ClientSessionData) 24 24 25 25 // Persisted information about an OAuth session. Used to resume an active session. 26 26 type ClientSessionData struct { ··· 62 62 Data *ClientSessionData 63 63 DpopPrivateKey crypto.PrivateKey 64 64 65 - RefreshCallback RefreshCallback 65 + PersistSessionCallback PersistSessionCallback 66 66 67 67 // Lock which protects concurrent access to session data (eg, access and refresh tokens) 68 68 lk sync.RWMutex 69 69 } 70 70 71 - func (sess *ClientSession) RefreshTokens(ctx context.Context) error { 71 + // Requests new tokens from auth server, and returns the new access token on success. 72 + // 73 + // Internally takes a lock on session data around the entire refresh process, including retries. Persists data using PersistSessionCallback if configured. 74 + func (sess *ClientSession) RefreshTokens(ctx context.Context) (string, error) { 75 + sess.lk.Lock() 76 + defer sess.lk.Unlock() 72 77 73 78 body := RefreshTokenRequest{ 74 79 ClientID: sess.Config.ClientID, 75 80 GrantType: "authorization_code", 76 81 RefreshToken: sess.Data.RefreshToken, 77 82 } 83 + tokenURL := sess.Data.AuthServerTokenEndpoint 78 84 79 85 if sess.Config.IsConfidential() { 80 86 clientAssertion, err := sess.Config.NewClientAssertion(sess.Data.AuthServerURL) 81 87 if err != nil { 82 - return err 88 + return "", err 83 89 } 84 90 body.ClientAssertionType = &CLIENT_ASSERTION_JWT_BEARER 85 91 body.ClientAssertion = &clientAssertion ··· 87 93 88 94 vals, err := query.Values(body) 89 95 if err != nil { 90 - return err 96 + return "", err 91 97 } 92 98 bodyBytes := []byte(vals.Encode()) 93 99 94 - // XXX: persist this back to the data? 95 - dpopServerNonce := sess.Data.DpopAuthServerNonce 96 - tokenURL := sess.Data.AuthServerTokenEndpoint 97 - 98 100 var resp *http.Response 99 101 for range 2 { 100 - dpopJWT, err := NewAuthDPoP("POST", tokenURL, dpopServerNonce, sess.DpopPrivateKey) 102 + dpopJWT, err := NewAuthDPoP("POST", sess.Data.AuthServerTokenEndpoint, sess.Data.DpopAuthServerNonce, sess.DpopPrivateKey) 101 103 if err != nil { 102 - return err 104 + return "", err 103 105 } 104 106 105 107 req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, bytes.NewBuffer(bodyBytes)) 106 108 if err != nil { 107 - return err 109 + return "", err 108 110 } 109 111 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 110 112 req.Header.Set("DPoP", dpopJWT) 111 113 112 114 resp, err = sess.Client.Do(req) 113 115 if err != nil { 114 - return err 116 + return "", err 115 117 } 116 118 117 - // check if a nonce was provided 118 - dpopServerNonce = resp.Header.Get("DPoP-Nonce") 119 - if resp.StatusCode == 400 && dpopServerNonce != "" { 120 - // TODO: also check that body is JSON with an 'error' string field value of 'use_dpop_nonce' 119 + // always check if a new DPoP nonce was provided, and proactively update session data (even if there was not an explicit error) 120 + dpopNonceHdr := resp.Header.Get("DPoP-Nonce") 121 + if dpopNonceHdr != "" && dpopNonceHdr != sess.Data.DpopAuthServerNonce { 122 + sess.Data.DpopAuthServerNonce = dpopNonceHdr 123 + } 124 + 125 + // check for an error condition caused by an out of date DPoP nonce 126 + // note that the HTTP status code would be 400 Bad Request on token endpoint, not 401 Unauthorized like it would be on Resource Server requests 127 + if resp.StatusCode == http.StatusBadRequest && dpopNonceHdr != "" { 128 + 129 + // parse the error body to confirm the error type 121 130 var errResp map[string]any 122 131 if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil { 123 - slog.Warn("initial token request failed", "authServer", tokenURL, "err", err, "statusCode", resp.StatusCode) 124 - } else { 125 - slog.Warn("initial token request failed", "authServer", tokenURL, "resp", errResp, "statusCode", resp.StatusCode) 132 + slog.Warn("token refresh failed, and could not parse response body", "authServer", tokenURL, "err", err, "statusCode", resp.StatusCode) 133 + resp.Body.Close() 134 + return "", fmt.Errorf("token refresh failed: HTTP %d", resp.StatusCode) 135 + } else if errResp["error"] != "use_dpop_nonce" { 136 + slog.Warn("token refresh failed", "authServer", tokenURL, "body", errResp, "statusCode", resp.StatusCode) 137 + return "", fmt.Errorf("token refresh failed: %s", errResp["error"]) 126 138 } 127 139 128 - // loop around try again 140 + // already updated nonce value above; loop around and try again 141 + // NOTE: having already parsed the body means that the error handling below could fail if we call out of 'for' loop 129 142 resp.Body.Close() 130 143 continue 131 144 } 145 + 132 146 // otherwise process result 133 147 break 134 148 } 135 149 136 150 defer resp.Body.Close() 137 - if resp.StatusCode != 200 { 151 + if resp.StatusCode != http.StatusOK { 138 152 var errResp map[string]any 139 153 if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil { 140 - slog.Warn("initial token request failed", "authServer", tokenURL, "err", err, "statusCode", resp.StatusCode) 141 - } else { 142 - slog.Warn("initial token request failed", "authServer", tokenURL, "resp", errResp, "statusCode", resp.StatusCode) 154 + slog.Warn("token refresh failed", "authServer", tokenURL, "err", err, "statusCode", resp.StatusCode) 155 + return "", fmt.Errorf("token refresh failed: HTTP %d", resp.StatusCode) 143 156 } 144 - return fmt.Errorf("initial token request failed: HTTP %d", resp.StatusCode) 157 + slog.Warn("token refresh failed", "authServer", tokenURL, "body", errResp, "statusCode", resp.StatusCode) 158 + return "", fmt.Errorf("token refresh failed: %s", errResp["error"]) 145 159 } 146 160 147 161 var tokenResp TokenResponse 148 162 if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { 149 - return fmt.Errorf("token response failed to decode: %w", err) 163 + return "", fmt.Errorf("token response failed to decode: %w", err) 150 164 } 151 - // XXX: more validation of response? 165 + // TODO: more validation of token refresh response? 152 166 153 167 sess.Data.AccessToken = tokenResp.AccessToken 154 168 sess.Data.RefreshToken = tokenResp.RefreshToken 155 169 156 - return nil 170 + // persist updated data (tokens and possibly nonce) 171 + if sess.PersistSessionCallback != nil { 172 + sess.PersistSessionCallback(ctx, sess.Data) 173 + } 174 + 175 + return sess.Data.AccessToken, nil 157 176 } 158 177 159 - func (sess *ClientSession) NewAccessDPoP(method, reqURL string) (string, error) { 178 + // Constructs and signs a DPoP JWT to include in request header to Host (aka Resource Server, aka PDS). These tokens are different from those used with Auth Server token endpoints (even if the PDS is filling both roles) 179 + func (sess *ClientSession) NewHostDPoP(method, reqURL string) (string, error) { 180 + sess.lk.RLock() 181 + defer sess.lk.RUnlock() 160 182 161 183 ath := S256CodeChallenge(sess.Data.AccessToken) 162 184 claims := dpopClaims{ ··· 205 227 return u2.String() 206 228 } 207 229 230 + // Parses a WWW-Authenticate response header to see if DPoP nonce update is indicated 231 + func isNonceUpdateHeader(hdr string) bool { 232 + // Example from RFC9449: 233 + // WWW-Authenticate: DPoP error="use_dpop_nonce", error_description="Resource server requires nonce in DPoP proof" 234 + return strings.Contains(hdr, "error=\"use_dpop_nonce\"") 235 + } 236 + 237 + // Parses a WWW-Authenticate response header to see if access token has expired (needs refresh) 238 + func isExpiredAccessTokenHeader(hdr string) bool { 239 + // Example from OAuth 2.1 draft: 240 + // WWW-Authenticate: Bearer error="invalid_token" error_description="The access token expired" 241 + // TODO: should this also look for "expired"? 242 + return strings.Contains(hdr, "error=\"invalid_token\"") 243 + } 244 + 245 + // Sends API request to OAuth Resource Server (PDS), using access token and DPoP. 246 + // 247 + // Automatically handles DPoP nonce updates and token refresh as needed, based on the response status code and `WWW-Authenticate` header. 208 248 func (sess *ClientSession) DoWithAuth(c *http.Client, req *http.Request, endpoint syntax.NSID) (*http.Response, error) { 209 249 210 250 durl := dpopURL(req.URL) 211 251 252 + //accessToken, dpopNonce := sess.GetHostData() 212 253 // XXX: fetch with mutex lock 213 254 accessToken := sess.Data.AccessToken 214 - originalNonce := sess.Data.DpopHostNonce 215 - dpopNonce := originalNonce 255 + dpopNonce := sess.Data.DpopHostNonce 216 256 257 + // this method may need to retry twice, once for DPoP nonce update and once for token refresh 217 258 var resp *http.Response 218 259 for range 3 { 219 - dpopJWT, err := sess.NewAccessDPoP(req.Method, durl) 260 + dpopJWT, err := sess.NewHostDPoP(req.Method, durl) 220 261 if err != nil { 221 262 return nil, err 222 263 } ··· 228 269 return nil, err 229 270 } 230 271 231 - // on success, or most errors, just return HTTP response 232 - if resp.StatusCode != http.StatusBadRequest || !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { 272 + // on Success, or many types of error, just return HTTP response 273 + // "Unauthorized" is HTTP status code 401 274 + if resp.StatusCode != http.StatusUnauthorized || resp.Header.Get("WWW-Authenticate") == "" { 233 275 return resp, nil 234 276 } 235 277 236 - // parse the error response body (JSON) and check the error name 237 - defer resp.Body.Close() 238 - var eb client.ErrorBody 239 - if err := json.NewDecoder(resp.Body).Decode(&eb); err != nil { 240 - return nil, &client.APIError{StatusCode: resp.StatusCode} 241 - } 278 + authHdr := resp.Header.Get("WWW-Authenticate") 279 + dpopNonceHdr := resp.Header.Get("DPoP-Nonce") 242 280 243 - // if DPoP nonce was stale, simply retry 244 - if eb.Name == "use_dpop_nonce" && resp.Header.Get("DPoP-Nonce") != "" { 245 - dpopNonce = resp.Header.Get("DPoP-Nonce") 246 - if dpopNonce != originalNonce { 247 - // XXX: persist new nonce value via callback 281 + // if DPoP nonce changed, update and retry request 282 + if isNonceUpdateHeader(authHdr) && dpopNonceHdr != "" { 283 + // TODO: validate or normalize dpopNonceHdr in some way? eg minimum length 284 + if dpopNonceHdr == dpopNonce { 285 + return nil, fmt.Errorf("OAuth PDS DPoP nonce failure, but no new nonce supplied") 248 286 } 249 - 287 + // XXX: persist new nonce value via callback 288 + sess.Data.DpopHostNonce = dpopNonceHdr 289 + dpopNonce = dpopNonceHdr 290 + // retry request 250 291 retry := req.Clone(req.Context()) 251 292 if req.GetBody != nil { 252 293 retry.Body, err = req.GetBody() 253 294 if err != nil { 254 - return nil, fmt.Errorf("API request retry GetBody failed: %w", err) 295 + return nil, fmt.Errorf("GetBody failed when retrying API request: %w", err) 255 296 } 256 297 } 257 298 req = retry 258 299 continue 259 300 } 260 301 261 - // if this is anything other than an expired token, bail out now 262 - if eb.Name != "ExpiredToken" { 263 - return nil, eb.APIError(resp.StatusCode) 264 - } 302 + // if access token expired, refresh and retry 303 + if isExpiredAccessTokenHeader(authHdr) { 304 + accessToken, err = sess.RefreshTokens(req.Context()) 305 + if err != nil { 306 + return nil, fmt.Errorf("failed to refresh OAuth tokens: %w", err) 307 + } 265 308 266 - if err := sess.RefreshTokens(req.Context()); err != nil { 267 - return nil, err 309 + retry := req.Clone(req.Context()) 310 + if req.GetBody != nil { 311 + retry.Body, err = req.GetBody() 312 + if err != nil { 313 + return nil, fmt.Errorf("GetBody failed when retrying API request: %w", err) 314 + } 315 + } 316 + req = retry 317 + continue 268 318 } 269 319 270 - // XXX: fetch with mutex lock 271 - accessToken = sess.Data.AccessToken 272 - 273 - retry := req.Clone(req.Context()) 274 - if req.GetBody != nil { 275 - retry.Body, err = req.GetBody() 276 - if err != nil { 277 - return nil, fmt.Errorf("API request retry GetBody failed: %w", err) 278 - } 279 - } 280 - req = retry 281 - continue 320 + // otherwise, this was some other type of auth failure; just return the full response 321 + // NOTE: in theory we could return an APIError here instead 322 + return resp, nil 282 323 } 283 324 284 - return nil, fmt.Errorf("OAuth client ran out of retries") 325 + return nil, fmt.Errorf("OAuth client ran out of request retries") 285 326 } 286 327 287 328 func (sess *ClientSession) APIClient() *client.APIClient {