Stateless auth proxy that converts AT Protocol native apps from public to confidential OAuth clients. Deploy once, get 180-day refresh tokens instead of 24-hour ones.
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

Harden auth proxy with issuer metadata validation, safe key rotation, and SSRF-safe outbound requests

+1143 -447
+47 -7
README.md
··· 53 53 | `AUTH_ALLOWED_ORIGINS` | No | `*` | CORS allowed origins | 54 54 | `AUTH_RATE_LIMIT_PER_IP` | No | `10` | Max requests per IP per minute on `/oauth/token` and `/oauth/par` (0 to disable) | 55 55 | `AUTH_RATE_LIMIT_GLOBAL` | No | `100` | Max total requests per minute on `/oauth/token` and `/oauth/par` (0 to disable) | 56 + | `AUTH_TRUST_PROXY_HEADERS` | No | `false` | Trust `X-Forwarded-For` / `X-Real-IP` for per-IP rate limiting when deployed behind a trusted reverse proxy | 56 57 57 58 ## How It Works 58 59 ··· 68 69 └─────────────┘ └──────────────────────┘ └─────────────────────┘ 69 70 ``` 70 71 71 - The proxy is stateless — no database, no session storage, no user data. It holds a private signing key and uses it to authenticate token requests on behalf of your app. 72 + The proxy is stateless — no database, no session storage, no user data. It holds the client signing key material, verifies issuer metadata before each proxied flow, and uses the selected key to authenticate token requests on behalf of your app. 72 73 73 74 1. Native app initiates OAuth and gets an auth code 74 75 2. App sends the auth code to the proxy (`POST /oauth/token`) ··· 81 82 82 83 DPoP proofs are generated on the device and forwarded through the proxy transparently. 83 84 85 + The proxy returns the selected signing key via the `Auth-Proxy-Key-ID` response header. Clients should persist that value and send it back as `key_id` on later `/oauth/token` refresh requests so sessions keep using the same key across rotations. 86 + 84 87 ## API Endpoints 85 88 86 89 | Method | Path | Description | ··· 90 93 | `POST` | `/oauth/par` | Proxy Pushed Authorization Requests | 91 94 | `GET` | `/health` | Health check | 92 95 96 + ### `POST /oauth/token` 97 + 98 + Request body: 99 + 100 + ```json 101 + { 102 + "token_endpoint": "https://bsky.social/oauth/token", 103 + "issuer": "https://bsky.social", 104 + "key_id": "atproto-auth-2", 105 + "grant_type": "refresh_token", 106 + "refresh_token": "<refresh_token>" 107 + } 108 + ``` 109 + 110 + `key_id` is optional, but clients should send it once they have seen an `Auth-Proxy-Key-ID` response header. During a rotation window, the proxy can also retry an older configured key automatically if the active key returns `invalid_client`. 111 + 112 + ### `POST /oauth/par` 113 + 114 + Request body: 115 + 116 + ```json 117 + { 118 + "par_endpoint": "https://bsky.social/oauth/par", 119 + "issuer": "https://bsky.social", 120 + "key_id": "atproto-auth-2", 121 + "login_hint": "user.bsky.social", 122 + "scope": "atproto transition:generic", 123 + "code_challenge": "<pkce_challenge>", 124 + "code_challenge_method": "S256", 125 + "state": "<state>", 126 + "redirect_uri": "yourapp://oauth/callback" 127 + } 128 + ``` 129 + 130 + The proxy validates that `issuer`, `token_endpoint`, and `par_endpoint` match the authorization server’s well-known metadata before forwarding the request. 131 + 93 132 ## Client Metadata Changes 94 133 95 134 Update your app's `client-metadata.json` to use the proxy: ··· 125 164 1. Generate a new key pair with a new `kid` (e.g., `atproto-auth-2`) 126 165 2. Set `AUTH_OLD_PRIVATE_KEY` and `AUTH_OLD_KEY_ID` to your current key values 127 166 3. Set `AUTH_PRIVATE_KEY` and `AUTH_KEY_ID` to the new key 128 - 4. Deploy — the JWKS now serves both keys; new assertions use the new key 167 + 4. Deploy — the JWKS now serves both keys; new PAR and token assertions use the active key by default 129 168 5. After 24+ hours, remove `AUTH_OLD_PRIVATE_KEY` and `AUTH_OLD_KEY_ID` 130 169 131 - The active key (`AUTH_PRIVATE_KEY`) is always used for signing new assertions. The old key is only published in the JWKS so auth servers can still verify existing sessions. 170 + Clients should persist the `Auth-Proxy-Key-ID` response header from PAR/token responses and send it back as `key_id` for later token exchanges and refreshes. During the overlap window, the proxy also retries the old configured key automatically if an otherwise-valid token request gets `invalid_client` from the auth server. 132 171 133 172 ## Security Considerations 134 173 135 - - **Token endpoint validation**: The proxy validates that upstream URLs use HTTPS, resolves hostnames via DNS to reject private addresses, and rejects private/localhost/link-local addresses to prevent SSRF 136 - - **Redirect protection**: Upstream HTTP redirects are validated to prevent redirection to private addresses 137 - - **Request timeout**: Upstream requests have a 30-second timeout to prevent slow-loris attacks 174 + - **Issuer metadata validation**: The proxy resolves the authorization server’s well-known metadata and verifies the requested token/PAR endpoint matches the declared issuer metadata before sending any signed request 175 + - **Hardened outbound HTTP**: Metadata and upstream requests use a public-IP-only transport, reject localhost/private/reserved ranges, and validate redirects to prevent SSRF and DNS rebinding 176 + - **Bounded upstream reads**: Metadata and proxied token/PAR responses are size-limited in memory to avoid oversized-response abuse 177 + - **Request timeout**: Metadata and upstream requests have explicit timeouts to prevent slow-loris attacks 138 178 - **No token logging**: Token values, auth codes, and refresh tokens are never logged 139 179 - **HTTPS required**: The proxy must be served over HTTPS in production (handled automatically by Railway/Fly.io) 140 180 - **DPoP passthrough**: The proxy never sees DPoP private keys — proofs are between the device and auth server 141 - - **Rate limiting**: Per-IP and global rate limits on `/oauth/token` and `/oauth/par` (configurable, defaults to 10/min per IP, 100/min global) 181 + - **Rate limiting**: Per-IP and global rate limits on `/oauth/token` and `/oauth/par` (configurable, defaults to 10/min per IP, 100/min global). Proxy headers are only trusted when `AUTH_TRUST_PROXY_HEADERS=true` 142 182 - **Stateless**: No database, no user data stored — the only secret is the client signing key in an environment variable 143 183 144 184 ## License
+174
authserver.go
··· 1 + package main 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "fmt" 7 + "io" 8 + "net/http" 9 + "net/url" 10 + "slices" 11 + "sync" 12 + "time" 13 + ) 14 + 15 + const ( 16 + authServerMetadataCacheTTL = 5 * time.Minute 17 + maxMetadataResponseBytes = 256 << 10 18 + maxUpstreamResponseBodySize = 1 << 20 19 + ) 20 + 21 + type authServerMetadata struct { 22 + Issuer string `json:"issuer"` 23 + TokenEndpoint string `json:"token_endpoint"` 24 + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` 25 + TokenEndpointAuthSigningAlgsSupported []string `json:"token_endpoint_auth_signing_alg_values_supported"` 26 + RequirePushedAuthorizationRequests bool `json:"require_pushed_authorization_requests"` 27 + PushedAuthorizationRequestEndpoint string `json:"pushed_authorization_request_endpoint"` 28 + } 29 + 30 + type cachedAuthServerMetadata struct { 31 + metadata *authServerMetadata 32 + expiresAt time.Time 33 + } 34 + 35 + var ( 36 + authServerMetadataCache = struct { 37 + mu sync.Mutex 38 + entries map[string]cachedAuthServerMetadata 39 + }{ 40 + entries: make(map[string]cachedAuthServerMetadata), 41 + } 42 + 43 + metadataClient = newPublicHTTPClient(10 * time.Second) 44 + ) 45 + 46 + func clearAuthServerMetadataCache() { 47 + authServerMetadataCache.mu.Lock() 48 + defer authServerMetadataCache.mu.Unlock() 49 + authServerMetadataCache.entries = make(map[string]cachedAuthServerMetadata) 50 + } 51 + 52 + func ResolveAuthServerMetadata(ctx context.Context, issuer string) (*authServerMetadata, error) { 53 + issuerURL, err := ValidateIssuer(issuer) 54 + if err != nil { 55 + return nil, invalidRequestError("invalid issuer") 56 + } 57 + 58 + now := time.Now() 59 + 60 + authServerMetadataCache.mu.Lock() 61 + entry, ok := authServerMetadataCache.entries[issuer] 62 + if ok && now.Before(entry.expiresAt) { 63 + authServerMetadataCache.mu.Unlock() 64 + return entry.metadata, nil 65 + } 66 + authServerMetadataCache.mu.Unlock() 67 + 68 + metadataURL := issuerURL.ResolveReference(&url.URL{Path: "/.well-known/oauth-authorization-server"}) 69 + req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL.String(), nil) 70 + if err != nil { 71 + return nil, upstreamRequestError("failed to create metadata request") 72 + } 73 + 74 + resp, err := metadataClient.Do(req) 75 + if err != nil { 76 + return nil, upstreamRequestError("failed to fetch authorization server metadata") 77 + } 78 + defer resp.Body.Close() 79 + 80 + if resp.StatusCode != http.StatusOK { 81 + return nil, upstreamRequestError(fmt.Sprintf("authorization server metadata returned HTTP %d", resp.StatusCode)) 82 + } 83 + 84 + body, err := io.ReadAll(io.LimitReader(resp.Body, maxMetadataResponseBytes+1)) 85 + if err != nil { 86 + return nil, upstreamRequestError("failed to read authorization server metadata") 87 + } 88 + if len(body) > maxMetadataResponseBytes { 89 + return nil, upstreamRequestError("authorization server metadata response was too large") 90 + } 91 + 92 + var metadata authServerMetadata 93 + if err := json.Unmarshal(body, &metadata); err != nil { 94 + return nil, upstreamRequestError("authorization server metadata was not valid JSON") 95 + } 96 + 97 + if err := metadata.Validate(issuer); err != nil { 98 + return nil, invalidRequestError(err.Error()) 99 + } 100 + 101 + authServerMetadataCache.mu.Lock() 102 + authServerMetadataCache.entries[issuer] = cachedAuthServerMetadata{ 103 + metadata: &metadata, 104 + expiresAt: now.Add(authServerMetadataCacheTTL), 105 + } 106 + authServerMetadataCache.mu.Unlock() 107 + 108 + return &metadata, nil 109 + } 110 + 111 + func (m *authServerMetadata) Validate(expectedIssuer string) error { 112 + if _, err := ValidateIssuer(m.Issuer); err != nil { 113 + return fmt.Errorf("issuer metadata contained an invalid issuer") 114 + } 115 + if m.Issuer != expectedIssuer { 116 + return fmt.Errorf("issuer metadata did not match the requested issuer") 117 + } 118 + if _, err := validateEndpointURL(m.TokenEndpoint); err != nil { 119 + return fmt.Errorf("issuer metadata contained an invalid token_endpoint") 120 + } 121 + if !slices.Contains(m.TokenEndpointAuthMethodsSupported, "private_key_jwt") { 122 + return fmt.Errorf("issuer metadata does not support private_key_jwt") 123 + } 124 + if !slices.Contains(m.TokenEndpointAuthSigningAlgsSupported, "ES256") { 125 + return fmt.Errorf("issuer metadata does not support ES256 client assertions") 126 + } 127 + if m.PushedAuthorizationRequestEndpoint != "" { 128 + if _, err := validateEndpointURL(m.PushedAuthorizationRequestEndpoint); err != nil { 129 + return fmt.Errorf("issuer metadata contained an invalid pushed_authorization_request_endpoint") 130 + } 131 + } 132 + if m.RequirePushedAuthorizationRequests && m.PushedAuthorizationRequestEndpoint == "" { 133 + return fmt.Errorf("issuer metadata requires PAR but does not advertise a PAR endpoint") 134 + } 135 + 136 + return nil 137 + } 138 + 139 + func ValidateTokenEndpointForIssuer(ctx context.Context, issuer string, tokenEndpoint string) error { 140 + if err := ValidateTokenEndpoint(tokenEndpoint); err != nil { 141 + return invalidRequestError("invalid token_endpoint") 142 + } 143 + 144 + metadata, err := ResolveAuthServerMetadata(ctx, issuer) 145 + if err != nil { 146 + return err 147 + } 148 + 149 + if metadata.TokenEndpoint != tokenEndpoint { 150 + return invalidRequestError("token_endpoint does not match issuer metadata") 151 + } 152 + 153 + return nil 154 + } 155 + 156 + func ValidatePAREndpointForIssuer(ctx context.Context, issuer string, parEndpoint string) error { 157 + if err := ValidatePAREndpoint(parEndpoint); err != nil { 158 + return invalidRequestError("invalid par_endpoint") 159 + } 160 + 161 + metadata, err := ResolveAuthServerMetadata(ctx, issuer) 162 + if err != nil { 163 + return err 164 + } 165 + 166 + if metadata.PushedAuthorizationRequestEndpoint == "" { 167 + return invalidRequestError("issuer does not advertise a pushed_authorization_request_endpoint") 168 + } 169 + if metadata.PushedAuthorizationRequestEndpoint != parEndpoint { 170 + return invalidRequestError("par_endpoint does not match issuer metadata") 171 + } 172 + 173 + return nil 174 + }
+29 -18
config.go
··· 7 7 ) 8 8 9 9 type Config struct { 10 - PrivateKeyPEM string 11 - ClientID string 12 - KeyID string 13 - OldPrivateKeyPEM string 14 - OldKeyID string 15 - Bind string 16 - AllowedOrigins string 17 - RateLimitPerIP int 18 - RateLimitGlobal int 10 + PrivateKeyPEM string 11 + ClientID string 12 + KeyID string 13 + OldPrivateKeyPEM string 14 + OldKeyID string 15 + Bind string 16 + AllowedOrigins string 17 + RateLimitPerIP int 18 + RateLimitGlobal int 19 + TrustProxyHeaders bool 19 20 } 20 21 21 22 func LoadConfig() (*Config, error) { ··· 65 66 rateLimitGlobal = n 66 67 } 67 68 69 + trustProxyHeaders := false 70 + if v := os.Getenv("AUTH_TRUST_PROXY_HEADERS"); v != "" { 71 + b, err := strconv.ParseBool(v) 72 + if err != nil { 73 + return nil, fmt.Errorf("AUTH_TRUST_PROXY_HEADERS must be a boolean") 74 + } 75 + trustProxyHeaders = b 76 + } 77 + 68 78 return &Config{ 69 - PrivateKeyPEM: privateKey, 70 - ClientID: clientID, 71 - KeyID: keyID, 72 - OldPrivateKeyPEM: oldPrivateKey, 73 - OldKeyID: oldKeyID, 74 - Bind: bind, 75 - AllowedOrigins: allowedOrigins, 76 - RateLimitPerIP: rateLimitPerIP, 77 - RateLimitGlobal: rateLimitGlobal, 79 + PrivateKeyPEM: privateKey, 80 + ClientID: clientID, 81 + KeyID: keyID, 82 + OldPrivateKeyPEM: oldPrivateKey, 83 + OldKeyID: oldKeyID, 84 + Bind: bind, 85 + AllowedOrigins: allowedOrigins, 86 + RateLimitPerIP: rateLimitPerIP, 87 + RateLimitGlobal: rateLimitGlobal, 88 + TrustProxyHeaders: trustProxyHeaders, 78 89 }, nil 79 90 }
+52
errors.go
··· 1 + package main 2 + 3 + import ( 4 + "encoding/json" 5 + "net/http" 6 + ) 7 + 8 + const authProxyKeyIDHeader = "Auth-Proxy-Key-ID" 9 + 10 + type apiError struct { 11 + Status int 12 + Code string 13 + Description string 14 + } 15 + 16 + func (e *apiError) Error() string { 17 + return e.Description 18 + } 19 + 20 + func invalidRequestError(description string) *apiError { 21 + return &apiError{ 22 + Status: http.StatusBadRequest, 23 + Code: "invalid_request", 24 + Description: description, 25 + } 26 + } 27 + 28 + func upstreamRequestError(description string) *apiError { 29 + return &apiError{ 30 + Status: http.StatusBadGateway, 31 + Code: "server_error", 32 + Description: description, 33 + } 34 + } 35 + 36 + func writeAPIError(w http.ResponseWriter, err error) { 37 + if apiErr, ok := err.(*apiError); ok { 38 + writeJSONError(w, apiErr.Status, apiErr.Code, apiErr.Description) 39 + return 40 + } 41 + 42 + writeJSONError(w, http.StatusInternalServerError, "server_error", "internal server error") 43 + } 44 + 45 + func writeJSONError(w http.ResponseWriter, status int, code string, description string) { 46 + w.Header().Set("Content-Type", "application/json") 47 + w.WriteHeader(status) 48 + _ = json.NewEncoder(w).Encode(map[string]string{ 49 + "error": code, 50 + "error_description": description, 51 + }) 52 + }
+29 -18
handler_par.go
··· 5 5 "log" 6 6 "net/http" 7 7 "net/url" 8 - 9 - "github.com/lestrrat-go/jwx/v2/jwk" 10 8 ) 11 9 12 10 type parRequest struct { 13 11 PAREndpoint string `json:"par_endpoint"` 14 12 Issuer string `json:"issuer"` 13 + KeyID string `json:"key_id,omitempty"` 15 14 LoginHint string `json:"login_hint,omitempty"` 16 15 Scope string `json:"scope"` 17 16 CodeChallenge string `json:"code_challenge"` ··· 20 19 RedirectURI string `json:"redirect_uri"` 21 20 } 22 21 23 - func HandlePAR(signingKey jwk.Key, clientID string) http.HandlerFunc { 22 + func HandlePAR(signers *SignerSet, clientID string) http.HandlerFunc { 24 23 return func(w http.ResponseWriter, r *http.Request) { 25 24 var req parRequest 26 25 if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 27 - http.Error(w, `{"error":"invalid_request","error_description":"invalid JSON body"}`, http.StatusBadRequest) 26 + writeJSONError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") 28 27 return 29 28 } 30 29 31 30 if req.PAREndpoint == "" { 32 - http.Error(w, `{"error":"invalid_request","error_description":"par_endpoint is required"}`, http.StatusBadRequest) 31 + writeJSONError(w, http.StatusBadRequest, "invalid_request", "par_endpoint is required") 33 32 return 34 33 } 35 - 36 34 if req.Issuer == "" { 37 - http.Error(w, `{"error":"invalid_request","error_description":"issuer is required"}`, http.StatusBadRequest) 35 + writeJSONError(w, http.StatusBadRequest, "invalid_request", "issuer is required") 38 36 return 39 37 } 40 38 41 - if err := ValidateTokenEndpoint(req.PAREndpoint); err != nil { 42 - http.Error(w, `{"error":"invalid_request","error_description":"invalid par_endpoint"}`, http.StatusBadRequest) 39 + if err := ValidatePAREndpointForIssuer(r.Context(), req.Issuer, req.PAREndpoint); err != nil { 40 + writeAPIError(w, err) 41 + return 42 + } 43 + 44 + candidateKeyIDs, err := signers.CandidateKeyIDs(req.KeyID) 45 + if err != nil { 46 + writeJSONError(w, http.StatusBadRequest, "invalid_request", err.Error()) 47 + return 48 + } 49 + 50 + signer, err := signers.Lookup(candidateKeyIDs[0]) 51 + if err != nil { 52 + writeJSONError(w, http.StatusBadRequest, "invalid_request", err.Error()) 43 53 return 44 54 } 45 55 46 - assertion, err := GenerateClientAssertion(signingKey, clientID, req.Issuer) 56 + assertion, err := GenerateClientAssertion(signer, clientID, req.Issuer) 47 57 if err != nil { 48 58 log.Printf("failed to generate client assertion: %v", err) 49 - http.Error(w, `{"error":"server_error","error_description":"failed to generate client assertion"}`, http.StatusInternalServerError) 59 + writeJSONError(w, http.StatusInternalServerError, "server_error", "failed to generate client assertion") 50 60 return 51 61 } 52 62 ··· 60 70 params.Set("code_challenge_method", req.CodeChallengeMethod) 61 71 params.Set("state", req.State) 62 72 params.Set("redirect_uri", req.RedirectURI) 63 - 64 73 if req.LoginHint != "" { 65 74 params.Set("login_hint", req.LoginHint) 66 75 } 67 76 68 - dpopHeader := r.Header.Get("DPoP") 69 - 70 - started, err := ProxyRequest(w, req.PAREndpoint, params, dpopHeader) 77 + proxied, err := PostForm(r.Context(), req.PAREndpoint, params, r.Header.Get("DPoP")) 71 78 if err != nil { 72 79 log.Printf("proxy request failed: %v", err) 73 - if !started { 74 - http.Error(w, `{"error":"server_error","error_description":"upstream request failed"}`, http.StatusBadGateway) 75 - } 80 + writeAPIError(w, upstreamRequestError("upstream request failed")) 81 + return 82 + } 83 + 84 + w.Header().Set(authProxyKeyIDHeader, candidateKeyIDs[0]) 85 + if err := WriteProxiedResponse(w, proxied); err != nil { 86 + log.Printf("failed to write proxied response: %v", err) 76 87 } 77 88 } 78 89 }
+74 -21
handler_token.go
··· 5 5 "log" 6 6 "net/http" 7 7 "net/url" 8 - 9 - "github.com/lestrrat-go/jwx/v2/jwk" 10 8 ) 11 9 12 10 type tokenRequest struct { 13 11 TokenEndpoint string `json:"token_endpoint"` 14 12 Issuer string `json:"issuer"` 13 + KeyID string `json:"key_id,omitempty"` 15 14 GrantType string `json:"grant_type"` 16 15 Code string `json:"code,omitempty"` 17 16 RedirectURI string `json:"redirect_uri,omitempty"` ··· 19 18 RefreshToken string `json:"refresh_token,omitempty"` 20 19 } 21 20 22 - func HandleToken(signingKey jwk.Key, clientID string) http.HandlerFunc { 21 + func HandleToken(signers *SignerSet, clientID string) http.HandlerFunc { 23 22 return func(w http.ResponseWriter, r *http.Request) { 24 23 var req tokenRequest 25 24 if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 26 - http.Error(w, `{"error":"invalid_request","error_description":"invalid JSON body"}`, http.StatusBadRequest) 25 + writeJSONError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") 27 26 return 28 27 } 29 28 30 29 if req.TokenEndpoint == "" { 31 - http.Error(w, `{"error":"invalid_request","error_description":"token_endpoint is required"}`, http.StatusBadRequest) 30 + writeJSONError(w, http.StatusBadRequest, "invalid_request", "token_endpoint is required") 32 31 return 33 32 } 34 - 35 33 if req.Issuer == "" { 36 - http.Error(w, `{"error":"invalid_request","error_description":"issuer is required"}`, http.StatusBadRequest) 34 + writeJSONError(w, http.StatusBadRequest, "invalid_request", "issuer is required") 37 35 return 38 36 } 39 - 40 37 if req.GrantType == "" { 41 - http.Error(w, `{"error":"invalid_request","error_description":"grant_type is required"}`, http.StatusBadRequest) 38 + writeJSONError(w, http.StatusBadRequest, "invalid_request", "grant_type is required") 42 39 return 43 40 } 44 41 45 - if err := ValidateTokenEndpoint(req.TokenEndpoint); err != nil { 46 - http.Error(w, `{"error":"invalid_request","error_description":"invalid token_endpoint"}`, http.StatusBadRequest) 42 + if err := ValidateTokenEndpointForIssuer(r.Context(), req.Issuer, req.TokenEndpoint); err != nil { 43 + writeAPIError(w, err) 47 44 return 48 45 } 49 46 50 - assertion, err := GenerateClientAssertion(signingKey, clientID, req.Issuer) 47 + candidateKeyIDs, err := signers.CandidateKeyIDs(req.KeyID) 51 48 if err != nil { 52 - log.Printf("failed to generate client assertion: %v", err) 53 - http.Error(w, `{"error":"server_error","error_description":"failed to generate client assertion"}`, http.StatusInternalServerError) 49 + writeJSONError(w, http.StatusBadRequest, "invalid_request", err.Error()) 54 50 return 55 51 } 56 52 57 53 params := url.Values{} 58 54 params.Set("grant_type", req.GrantType) 59 55 params.Set("client_id", clientID) 60 - params.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") 61 - params.Set("client_assertion", assertion) 62 56 63 57 if req.Code != "" { 64 58 params.Set("code", req.Code) ··· 75 69 76 70 dpopHeader := r.Header.Get("DPoP") 77 71 78 - started, err := ProxyRequest(w, req.TokenEndpoint, params, dpopHeader) 79 - if err != nil { 80 - log.Printf("proxy request failed: %v", err) 81 - if !started { 82 - http.Error(w, `{"error":"server_error","error_description":"upstream request failed"}`, http.StatusBadGateway) 72 + var proxied *upstreamResponse 73 + var usedKeyID string 74 + 75 + for i, keyID := range candidateKeyIDs { 76 + signer, err := signers.Lookup(keyID) 77 + if err != nil { 78 + writeJSONError(w, http.StatusBadRequest, "invalid_request", err.Error()) 79 + return 80 + } 81 + 82 + assertion, err := GenerateClientAssertion(signer, clientID, req.Issuer) 83 + if err != nil { 84 + log.Printf("failed to generate client assertion: %v", err) 85 + writeJSONError(w, http.StatusInternalServerError, "server_error", "failed to generate client assertion") 86 + return 87 + } 88 + 89 + attemptParams := cloneValues(params) 90 + attemptParams.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") 91 + attemptParams.Set("client_assertion", assertion) 92 + 93 + proxied, err = PostForm(r.Context(), req.TokenEndpoint, attemptParams, dpopHeader) 94 + if err != nil { 95 + log.Printf("proxy request failed: %v", err) 96 + writeAPIError(w, upstreamRequestError("upstream request failed")) 97 + return 83 98 } 99 + 100 + usedKeyID = keyID 101 + if req.KeyID == "" && i < len(candidateKeyIDs)-1 && isInvalidClientResponse(proxied) { 102 + continue 103 + } 104 + 105 + break 106 + } 107 + 108 + w.Header().Set(authProxyKeyIDHeader, usedKeyID) 109 + if err := WriteProxiedResponse(w, proxied); err != nil { 110 + log.Printf("failed to write proxied response: %v", err) 84 111 } 85 112 } 86 113 } 114 + 115 + func isInvalidClientResponse(resp *upstreamResponse) bool { 116 + if resp == nil { 117 + return false 118 + } 119 + if resp.statusCode != http.StatusBadRequest && resp.statusCode != http.StatusUnauthorized { 120 + return false 121 + } 122 + 123 + var payload struct { 124 + Error string `json:"error"` 125 + } 126 + if err := json.Unmarshal(resp.body, &payload); err != nil { 127 + return false 128 + } 129 + 130 + return payload.Error == "invalid_client" 131 + } 132 + 133 + func cloneValues(values url.Values) url.Values { 134 + cloned := make(url.Values, len(values)) 135 + for key, entries := range values { 136 + cloned[key] = append([]string(nil), entries...) 137 + } 138 + return cloned 139 + }
+4 -4
main.go
··· 44 44 log.Fatalf("failed to build JWKS: %v", err) 45 45 } 46 46 47 - signingKey, err := NewSigner(privateKey, cfg.KeyID) 47 + signers, err := NewSignerSet(keys, cfg.KeyID) 48 48 if err != nil { 49 - log.Fatalf("failed to create signer: %v", err) 49 + log.Fatalf("failed to create signer set: %v", err) 50 50 } 51 51 52 52 rl := newRateLimiter(cfg.RateLimitPerIP, cfg.RateLimitGlobal) ··· 54 54 mux := http.NewServeMux() 55 55 56 56 mux.HandleFunc("GET /.well-known/jwks.json", HandleJWKS(jwksJSON)) 57 - mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, HandleToken(signingKey, cfg.ClientID))) 58 - mux.HandleFunc("POST /oauth/par", RateLimitMiddleware(rl, HandlePAR(signingKey, cfg.ClientID))) 57 + mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, cfg.TrustProxyHeaders, HandleToken(signers, cfg.ClientID))) 58 + mux.HandleFunc("POST /oauth/par", RateLimitMiddleware(rl, cfg.TrustProxyHeaders, HandlePAR(signers, cfg.ClientID))) 59 59 mux.HandleFunc("GET /health", HandleHealth) 60 60 61 61 handler := CORSMiddleware(cfg.AllowedOrigins, mux)
+322 -118
main_test.go
··· 1 1 package main 2 2 3 3 import ( 4 + "encoding/base64" 4 5 "encoding/json" 5 6 "io" 6 7 "net/http" ··· 10 11 "testing" 11 12 ) 12 13 14 + const testClientID = "https://example.com/oauth/client-metadata.json" 15 + 16 + type testAuthServerConfig struct { 17 + tokenEndpoint string 18 + parEndpoint string 19 + tokenHandler http.HandlerFunc 20 + parHandler http.HandlerFunc 21 + } 22 + 13 23 func setupTestServer(t *testing.T) (*httptest.Server, func()) { 14 24 t.Helper() 15 - return setupTestServerWithRateLimit(t, 0, 0) 25 + return setupTestServerWithConfig(t, 0, 0, false, nil, "") 16 26 } 17 27 18 28 func setupTestServerWithRateLimit(t *testing.T, perIP, global int) (*httptest.Server, func()) { 19 29 t.Helper() 30 + return setupTestServerWithConfig(t, perIP, global, false, nil, "") 31 + } 20 32 21 - pemData := generateTestPEM(t) 22 - key, err := ParsePrivateKey(pemData) 23 - if err != nil { 24 - t.Fatalf("failed to parse key: %v", err) 33 + func setupTestServerWithConfig(t *testing.T, perIP, global int, trustProxyHeaders bool, keys []keyEntry, activeKeyID string) (*httptest.Server, func()) { 34 + t.Helper() 35 + clearAuthServerMetadataCache() 36 + 37 + if len(keys) == 0 { 38 + pemData := generateTestPEM(t) 39 + key, err := ParsePrivateKey(pemData) 40 + if err != nil { 41 + t.Fatalf("failed to parse key: %v", err) 42 + } 43 + keys = []keyEntry{{privateKey: key, kid: "test-kid"}} 44 + activeKeyID = "test-kid" 25 45 } 26 46 27 - jwksJSON, err := BuildJWKS([]keyEntry{{privateKey: key, kid: "test-kid"}}) 47 + jwksJSON, err := BuildJWKS(keys) 28 48 if err != nil { 29 49 t.Fatalf("failed to build JWKS: %v", err) 30 50 } 31 51 32 - signingKey, err := NewSigner(key, "test-kid") 52 + signers, err := NewSignerSet(keys, activeKeyID) 33 53 if err != nil { 34 - t.Fatalf("failed to create signer: %v", err) 54 + t.Fatalf("failed to create signer set: %v", err) 35 55 } 36 56 37 - clientID := "https://example.com/oauth/client-metadata.json" 38 57 rl := newRateLimiter(perIP, global) 39 58 40 59 mux := http.NewServeMux() 41 60 mux.HandleFunc("GET /.well-known/jwks.json", HandleJWKS(jwksJSON)) 42 - mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, HandleToken(signingKey, clientID))) 43 - mux.HandleFunc("POST /oauth/par", RateLimitMiddleware(rl, HandlePAR(signingKey, clientID))) 61 + mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, trustProxyHeaders, HandleToken(signers, testClientID))) 62 + mux.HandleFunc("POST /oauth/par", RateLimitMiddleware(rl, trustProxyHeaders, HandlePAR(signers, testClientID))) 44 63 mux.HandleFunc("GET /health", HandleHealth) 45 64 46 65 handler := CORSMiddleware("*", mux) ··· 49 68 return srv, func() { srv.Close() } 50 69 } 51 70 71 + func useTestHTTPClients(client *http.Client) func() { 72 + oldUpstreamClient := upstreamClient 73 + oldMetadataClient := metadataClient 74 + upstreamClient = client 75 + metadataClient = client 76 + 77 + return func() { 78 + upstreamClient = oldUpstreamClient 79 + metadataClient = oldMetadataClient 80 + } 81 + } 82 + 83 + func newTestAuthServer(t *testing.T, cfg testAuthServerConfig) *httptest.Server { 84 + t.Helper() 85 + 86 + var srv *httptest.Server 87 + mux := http.NewServeMux() 88 + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { 89 + tokenEndpoint := cfg.tokenEndpoint 90 + if tokenEndpoint == "" { 91 + tokenEndpoint = srv.URL + "/oauth/token" 92 + } 93 + 94 + parEndpoint := cfg.parEndpoint 95 + if parEndpoint == "" { 96 + parEndpoint = srv.URL + "/oauth/par" 97 + } 98 + 99 + w.Header().Set("Content-Type", "application/json") 100 + _ = json.NewEncoder(w).Encode(authServerMetadata{ 101 + Issuer: srv.URL, 102 + TokenEndpoint: tokenEndpoint, 103 + TokenEndpointAuthMethodsSupported: []string{"none", "private_key_jwt"}, 104 + TokenEndpointAuthSigningAlgsSupported: []string{"ES256"}, 105 + RequirePushedAuthorizationRequests: true, 106 + PushedAuthorizationRequestEndpoint: parEndpoint, 107 + }) 108 + }) 109 + mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) { 110 + if cfg.tokenHandler == nil { 111 + http.NotFound(w, r) 112 + return 113 + } 114 + cfg.tokenHandler(w, r) 115 + }) 116 + mux.HandleFunc("/oauth/par", func(w http.ResponseWriter, r *http.Request) { 117 + if cfg.parHandler == nil { 118 + http.NotFound(w, r) 119 + return 120 + } 121 + cfg.parHandler(w, r) 122 + }) 123 + 124 + srv = httptest.NewTLSServer(mux) 125 + u, err := url.Parse(srv.URL) 126 + if err != nil { 127 + t.Fatalf("failed to parse test auth server URL: %v", err) 128 + } 129 + allowHost(u.Hostname()) 130 + return srv 131 + } 132 + 133 + func clientAssertionKeyID(t *testing.T, assertion string) string { 134 + t.Helper() 135 + 136 + parts := strings.Split(assertion, ".") 137 + if len(parts) < 2 { 138 + t.Fatalf("invalid JWT: %q", assertion) 139 + } 140 + 141 + headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) 142 + if err != nil { 143 + t.Fatalf("failed to decode JWT header: %v", err) 144 + } 145 + 146 + var header struct { 147 + KeyID string `json:"kid"` 148 + } 149 + if err := json.Unmarshal(headerBytes, &header); err != nil { 150 + t.Fatalf("failed to unmarshal JWT header: %v", err) 151 + } 152 + 153 + return header.KeyID 154 + } 155 + 52 156 func TestHealthEndpoint(t *testing.T) { 53 157 srv, cleanup := setupTestServer(t) 54 158 defer cleanup() ··· 86 190 if resp.StatusCode != http.StatusOK { 87 191 t.Errorf("expected 200, got %d", resp.StatusCode) 88 192 } 89 - 90 193 if ct := resp.Header.Get("Content-Type"); ct != "application/json" { 91 194 t.Errorf("expected Content-Type application/json, got %s", ct) 92 195 } 93 - 94 196 if cc := resp.Header.Get("Cache-Control"); cc != "public, max-age=3600" { 95 197 t.Errorf("expected Cache-Control public, max-age=3600, got %s", cc) 96 198 } ··· 160 262 } 161 263 } 162 264 163 - func allowTestHost(t *testing.T, serverURL string) { 164 - t.Helper() 165 - u, err := url.Parse(serverURL) 265 + func TestTokenEndpoint_RejectsIssuerEndpointMismatch(t *testing.T) { 266 + authServer := newTestAuthServer(t, testAuthServerConfig{}) 267 + defer authServer.Close() 268 + defer useTestHTTPClients(authServer.Client())() 269 + 270 + srv, cleanup := setupTestServer(t) 271 + defer cleanup() 272 + 273 + body := `{"token_endpoint":"` + authServer.URL + `/oauth/par","issuer":"` + authServer.URL + `","grant_type":"authorization_code","code":"test"}` 274 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 166 275 if err != nil { 167 - t.Fatalf("failed to parse test server URL: %v", err) 276 + t.Fatalf("request failed: %v", err) 168 277 } 169 - validatedHosts.Store(u.Hostname(), true) 278 + defer resp.Body.Close() 279 + 280 + if resp.StatusCode != http.StatusBadRequest { 281 + t.Errorf("expected 400, got %d", resp.StatusCode) 282 + } 170 283 } 171 284 172 285 func TestTokenEndpoint_ProxiesWithAssertion(t *testing.T) { 173 - // Create a mock upstream auth server 174 - var receivedParams url.Values 286 + var receivedParams map[string]string 175 287 var receivedDPoP string 176 - upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 177 - if err := r.ParseForm(); err != nil { 178 - http.Error(w, "bad form", 400) 179 - return 180 - } 181 - receivedParams = r.Form 182 - receivedDPoP = r.Header.Get("DPoP") 183 - w.Header().Set("DPoP-Nonce", "test-nonce-123") 184 - w.Header().Set("Content-Type", "application/json") 185 - w.WriteHeader(http.StatusOK) 186 - w.Write([]byte(`{"access_token":"at_test","token_type":"DPoP","expires_in":300}`)) 187 - })) 188 - defer upstream.Close() 189 - allowTestHost(t, upstream.URL) 190 288 191 - // Use the TLS test server's client for proxying 192 - oldClient := upstreamClient 193 - upstreamClient = upstream.Client() 194 - defer func() { upstreamClient = oldClient }() 289 + authServer := newTestAuthServer(t, testAuthServerConfig{ 290 + tokenHandler: func(w http.ResponseWriter, r *http.Request) { 291 + if err := r.ParseForm(); err != nil { 292 + http.Error(w, "bad form", http.StatusBadRequest) 293 + return 294 + } 295 + 296 + receivedParams = map[string]string{ 297 + "client_assertion": r.Form.Get("client_assertion"), 298 + "client_assertion_type": r.Form.Get("client_assertion_type"), 299 + "client_id": r.Form.Get("client_id"), 300 + "grant_type": r.Form.Get("grant_type"), 301 + "code": r.Form.Get("code"), 302 + "redirect_uri": r.Form.Get("redirect_uri"), 303 + "code_verifier": r.Form.Get("code_verifier"), 304 + } 305 + receivedDPoP = r.Header.Get("DPoP") 306 + w.Header().Set("DPoP-Nonce", "test-nonce-123") 307 + w.Header().Set("Content-Type", "application/json") 308 + w.WriteHeader(http.StatusOK) 309 + _, _ = w.Write([]byte(`{"access_token":"at_test","token_type":"DPoP","expires_in":300}`)) 310 + }, 311 + }) 312 + defer authServer.Close() 313 + defer useTestHTTPClients(authServer.Client())() 195 314 196 315 srv, cleanup := setupTestServer(t) 197 316 defer cleanup() 198 317 199 318 body := `{ 200 - "token_endpoint":"` + upstream.URL + `/oauth/token", 201 - "issuer":"https://bsky.social", 319 + "token_endpoint":"` + authServer.URL + `/oauth/token", 320 + "issuer":"` + authServer.URL + `", 202 321 "grant_type":"authorization_code", 203 322 "code":"test-auth-code", 204 323 "redirect_uri":"myapp://callback", 205 324 "code_verifier":"test-verifier" 206 325 }` 207 326 208 - req, err := http.NewRequest("POST", srv.URL+"/oauth/token", strings.NewReader(body)) 327 + req, err := http.NewRequest(http.MethodPost, srv.URL+"/oauth/token", strings.NewReader(body)) 209 328 if err != nil { 210 329 t.Fatalf("failed to create request: %v", err) 211 330 } ··· 222 341 respBody, _ := io.ReadAll(resp.Body) 223 342 t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(respBody)) 224 343 } 344 + if resp.Header.Get(authProxyKeyIDHeader) != "test-kid" { 345 + t.Errorf("expected %s=test-kid, got %q", authProxyKeyIDHeader, resp.Header.Get(authProxyKeyIDHeader)) 346 + } 225 347 226 - // Verify client_assertion was added 227 - if receivedParams.Get("client_assertion") == "" { 348 + if receivedParams["client_assertion"] == "" { 228 349 t.Error("expected client_assertion to be added") 229 350 } 230 - if receivedParams.Get("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 231 - t.Errorf("unexpected client_assertion_type: %s", receivedParams.Get("client_assertion_type")) 351 + if receivedParams["client_assertion_type"] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 352 + t.Errorf("unexpected client_assertion_type: %s", receivedParams["client_assertion_type"]) 232 353 } 233 - if receivedParams.Get("client_id") != "https://example.com/oauth/client-metadata.json" { 234 - t.Errorf("unexpected client_id: %s", receivedParams.Get("client_id")) 354 + if receivedParams["client_id"] != testClientID { 355 + t.Errorf("unexpected client_id: %s", receivedParams["client_id"]) 235 356 } 236 - if receivedParams.Get("grant_type") != "authorization_code" { 237 - t.Errorf("unexpected grant_type: %s", receivedParams.Get("grant_type")) 357 + if receivedParams["grant_type"] != "authorization_code" { 358 + t.Errorf("unexpected grant_type: %s", receivedParams["grant_type"]) 238 359 } 239 - if receivedParams.Get("code") != "test-auth-code" { 240 - t.Errorf("unexpected code: %s", receivedParams.Get("code")) 360 + if receivedParams["code"] != "test-auth-code" { 361 + t.Errorf("unexpected code: %s", receivedParams["code"]) 241 362 } 242 - if receivedParams.Get("redirect_uri") != "myapp://callback" { 243 - t.Errorf("unexpected redirect_uri: %s", receivedParams.Get("redirect_uri")) 363 + if receivedParams["redirect_uri"] != "myapp://callback" { 364 + t.Errorf("unexpected redirect_uri: %s", receivedParams["redirect_uri"]) 244 365 } 245 - if receivedParams.Get("code_verifier") != "test-verifier" { 246 - t.Errorf("unexpected code_verifier: %s", receivedParams.Get("code_verifier")) 366 + if receivedParams["code_verifier"] != "test-verifier" { 367 + t.Errorf("unexpected code_verifier: %s", receivedParams["code_verifier"]) 247 368 } 248 - 249 - // Verify DPoP was forwarded 369 + if clientAssertionKeyID(t, receivedParams["client_assertion"]) != "test-kid" { 370 + t.Errorf("expected active signing key in client assertion") 371 + } 250 372 if receivedDPoP != "test-dpop-proof" { 251 373 t.Errorf("expected DPoP header to be forwarded, got %q", receivedDPoP) 252 374 } 253 - 254 - // Verify DPoP-Nonce header was proxied back 255 375 if resp.Header.Get("DPoP-Nonce") != "test-nonce-123" { 256 376 t.Errorf("expected DPoP-Nonce header, got %q", resp.Header.Get("DPoP-Nonce")) 257 377 } 258 378 259 - // Verify response body was proxied 260 379 var tokenResp map[string]interface{} 261 380 if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { 262 381 t.Fatalf("failed to decode response: %v", err) ··· 266 385 } 267 386 } 268 387 269 - func TestTokenEndpoint_UpstreamErrorProxied(t *testing.T) { 270 - upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 271 - w.Header().Set("Content-Type", "application/json") 272 - w.WriteHeader(http.StatusBadRequest) 273 - w.Write([]byte(`{"error":"invalid_grant","error_description":"auth code expired"}`)) 274 - })) 275 - defer upstream.Close() 276 - allowTestHost(t, upstream.URL) 388 + func TestTokenEndpoint_FallsBackToOldKeyOnInvalidClient(t *testing.T) { 389 + newPEM := generateTestPEM(t) 390 + newKey, err := ParsePrivateKey(newPEM) 391 + if err != nil { 392 + t.Fatalf("failed to parse new key: %v", err) 393 + } 394 + oldPEM := generateTestPEM(t) 395 + oldKey, err := ParsePrivateKey(oldPEM) 396 + if err != nil { 397 + t.Fatalf("failed to parse old key: %v", err) 398 + } 277 399 278 - oldClient := upstreamClient 279 - upstreamClient = upstream.Client() 280 - defer func() { upstreamClient = oldClient }() 400 + var attemptedKids []string 401 + authServer := newTestAuthServer(t, testAuthServerConfig{ 402 + tokenHandler: func(w http.ResponseWriter, r *http.Request) { 403 + if err := r.ParseForm(); err != nil { 404 + http.Error(w, "bad form", http.StatusBadRequest) 405 + return 406 + } 407 + 408 + kid := clientAssertionKeyID(t, r.Form.Get("client_assertion")) 409 + attemptedKids = append(attemptedKids, kid) 410 + if kid == "new-kid" { 411 + w.Header().Set("Content-Type", "application/json") 412 + w.WriteHeader(http.StatusBadRequest) 413 + _, _ = w.Write([]byte(`{"error":"invalid_client","error_description":"session bound to old key"}`)) 414 + return 415 + } 416 + 417 + w.Header().Set("Content-Type", "application/json") 418 + w.WriteHeader(http.StatusOK) 419 + _, _ = w.Write([]byte(`{"access_token":"rotated","token_type":"DPoP","expires_in":300}`)) 420 + }, 421 + }) 422 + defer authServer.Close() 423 + defer useTestHTTPClients(authServer.Client())() 424 + 425 + srv, cleanup := setupTestServerWithConfig(t, 0, 0, false, []keyEntry{ 426 + {privateKey: newKey, kid: "new-kid"}, 427 + {privateKey: oldKey, kid: "old-kid"}, 428 + }, "new-kid") 429 + defer cleanup() 430 + 431 + body := `{ 432 + "token_endpoint":"` + authServer.URL + `/oauth/token", 433 + "issuer":"` + authServer.URL + `", 434 + "grant_type":"refresh_token", 435 + "refresh_token":"test-refresh-token" 436 + }` 437 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 438 + if err != nil { 439 + t.Fatalf("request failed: %v", err) 440 + } 441 + defer resp.Body.Close() 442 + 443 + if resp.StatusCode != http.StatusOK { 444 + respBody, _ := io.ReadAll(resp.Body) 445 + t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(respBody)) 446 + } 447 + 448 + if strings.Join(attemptedKids, ",") != "new-kid,old-kid" { 449 + t.Fatalf("expected fallback from new to old key, got %v", attemptedKids) 450 + } 451 + if resp.Header.Get(authProxyKeyIDHeader) != "old-kid" { 452 + t.Errorf("expected %s=old-kid, got %q", authProxyKeyIDHeader, resp.Header.Get(authProxyKeyIDHeader)) 453 + } 454 + } 455 + 456 + func TestTokenEndpoint_UpstreamErrorProxied(t *testing.T) { 457 + authServer := newTestAuthServer(t, testAuthServerConfig{ 458 + tokenHandler: func(w http.ResponseWriter, r *http.Request) { 459 + w.Header().Set("Content-Type", "application/json") 460 + w.WriteHeader(http.StatusBadRequest) 461 + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"auth code expired"}`)) 462 + }, 463 + }) 464 + defer authServer.Close() 465 + defer useTestHTTPClients(authServer.Client())() 281 466 282 467 srv, cleanup := setupTestServer(t) 283 468 defer cleanup() 284 469 285 - body := `{"token_endpoint":"` + upstream.URL + `/oauth/token","issuer":"https://bsky.social","grant_type":"authorization_code","code":"expired-code"}` 470 + body := `{"token_endpoint":"` + authServer.URL + `/oauth/token","issuer":"` + authServer.URL + `","grant_type":"authorization_code","code":"expired-code"}` 286 471 resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 287 472 if err != nil { 288 473 t.Fatalf("request failed: %v", err) ··· 303 488 } 304 489 305 490 func TestTokenEndpoint_UpstreamUnreachable(t *testing.T) { 491 + authServer := newTestAuthServer(t, testAuthServerConfig{ 492 + tokenEndpoint: "https://unreachable-test-host.invalid:19999/oauth/token", 493 + }) 494 + defer authServer.Close() 495 + defer useTestHTTPClients(authServer.Client())() 496 + 306 497 srv, cleanup := setupTestServer(t) 307 498 defer cleanup() 308 499 309 - // Point to a host that won't connect — use a pre-allowed host with a bad port 310 - allowTestHost(t, "https://unreachable-test-host.invalid") 311 - 312 - body := `{"token_endpoint":"https://unreachable-test-host.invalid:19999/oauth/token","issuer":"https://unreachable-test-host.invalid","grant_type":"authorization_code","code":"test"}` 500 + body := `{"token_endpoint":"https://unreachable-test-host.invalid:19999/oauth/token","issuer":"` + authServer.URL + `","grant_type":"authorization_code","code":"test"}` 313 501 resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 314 502 if err != nil { 315 503 t.Fatalf("request failed: %v", err) ··· 351 539 } 352 540 353 541 func TestPAREndpoint_ProxiesWithAssertion(t *testing.T) { 354 - var receivedParams url.Values 542 + var receivedParams map[string]string 355 543 var receivedDPoP string 356 - upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 357 - if err := r.ParseForm(); err != nil { 358 - http.Error(w, "bad form", 400) 359 - return 360 - } 361 - receivedParams = r.Form 362 - receivedDPoP = r.Header.Get("DPoP") 363 - w.Header().Set("Content-Type", "application/json") 364 - w.WriteHeader(http.StatusCreated) 365 - w.Write([]byte(`{"request_uri":"urn:ietf:params:oauth:request_uri:abc123","expires_in":60}`)) 366 - })) 367 - defer upstream.Close() 368 - allowTestHost(t, upstream.URL) 544 + 545 + authServer := newTestAuthServer(t, testAuthServerConfig{ 546 + parHandler: func(w http.ResponseWriter, r *http.Request) { 547 + if err := r.ParseForm(); err != nil { 548 + http.Error(w, "bad form", http.StatusBadRequest) 549 + return 550 + } 369 551 370 - oldClient := upstreamClient 371 - upstreamClient = upstream.Client() 372 - defer func() { upstreamClient = oldClient }() 552 + receivedParams = map[string]string{ 553 + "client_assertion": r.Form.Get("client_assertion"), 554 + "client_assertion_type": r.Form.Get("client_assertion_type"), 555 + "response_type": r.Form.Get("response_type"), 556 + "scope": r.Form.Get("scope"), 557 + "login_hint": r.Form.Get("login_hint"), 558 + "code_challenge": r.Form.Get("code_challenge"), 559 + "state": r.Form.Get("state"), 560 + "redirect_uri": r.Form.Get("redirect_uri"), 561 + } 562 + receivedDPoP = r.Header.Get("DPoP") 563 + w.Header().Set("Content-Type", "application/json") 564 + w.WriteHeader(http.StatusCreated) 565 + _, _ = w.Write([]byte(`{"request_uri":"urn:ietf:params:oauth:request_uri:abc123","expires_in":60}`)) 566 + }, 567 + }) 568 + defer authServer.Close() 569 + defer useTestHTTPClients(authServer.Client())() 373 570 374 571 srv, cleanup := setupTestServer(t) 375 572 defer cleanup() 376 573 377 574 body := `{ 378 - "par_endpoint":"` + upstream.URL + `/oauth/par", 379 - "issuer":"https://bsky.social", 575 + "par_endpoint":"` + authServer.URL + `/oauth/par", 576 + "issuer":"` + authServer.URL + `", 380 577 "login_hint":"user.bsky.social", 381 578 "scope":"atproto transition:generic", 382 579 "code_challenge":"test-challenge", ··· 385 582 "redirect_uri":"myapp://callback" 386 583 }` 387 584 388 - req, err := http.NewRequest("POST", srv.URL+"/oauth/par", strings.NewReader(body)) 585 + req, err := http.NewRequest(http.MethodPost, srv.URL+"/oauth/par", strings.NewReader(body)) 389 586 if err != nil { 390 587 t.Fatalf("failed to create request: %v", err) 391 588 } ··· 403 600 t.Fatalf("expected 201, got %d: %s", resp.StatusCode, string(respBody)) 404 601 } 405 602 406 - // Verify client_assertion was added 407 - if receivedParams.Get("client_assertion") == "" { 603 + if receivedParams["client_assertion"] == "" { 408 604 t.Error("expected client_assertion to be added") 409 605 } 410 - if receivedParams.Get("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 411 - t.Errorf("unexpected client_assertion_type: %s", receivedParams.Get("client_assertion_type")) 606 + if receivedParams["client_assertion_type"] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 607 + t.Errorf("unexpected client_assertion_type: %s", receivedParams["client_assertion_type"]) 412 608 } 413 - if receivedParams.Get("response_type") != "code" { 414 - t.Errorf("expected response_type=code, got %s", receivedParams.Get("response_type")) 609 + if receivedParams["response_type"] != "code" { 610 + t.Errorf("expected response_type=code, got %s", receivedParams["response_type"]) 415 611 } 416 - if receivedParams.Get("scope") != "atproto transition:generic" { 417 - t.Errorf("unexpected scope: %s", receivedParams.Get("scope")) 612 + if receivedParams["scope"] != "atproto transition:generic" { 613 + t.Errorf("unexpected scope: %s", receivedParams["scope"]) 418 614 } 419 - if receivedParams.Get("login_hint") != "user.bsky.social" { 420 - t.Errorf("unexpected login_hint: %s", receivedParams.Get("login_hint")) 615 + if receivedParams["login_hint"] != "user.bsky.social" { 616 + t.Errorf("unexpected login_hint: %s", receivedParams["login_hint"]) 421 617 } 422 - if receivedParams.Get("code_challenge") != "test-challenge" { 423 - t.Errorf("unexpected code_challenge: %s", receivedParams.Get("code_challenge")) 618 + if receivedParams["code_challenge"] != "test-challenge" { 619 + t.Errorf("unexpected code_challenge: %s", receivedParams["code_challenge"]) 620 + } 621 + if receivedParams["state"] != "test-state" { 622 + t.Errorf("unexpected state: %s", receivedParams["state"]) 424 623 } 425 - if receivedParams.Get("state") != "test-state" { 426 - t.Errorf("unexpected state: %s", receivedParams.Get("state")) 624 + if receivedParams["redirect_uri"] != "myapp://callback" { 625 + t.Errorf("unexpected redirect_uri: %s", receivedParams["redirect_uri"]) 427 626 } 428 - if receivedParams.Get("redirect_uri") != "myapp://callback" { 429 - t.Errorf("unexpected redirect_uri: %s", receivedParams.Get("redirect_uri")) 627 + if clientAssertionKeyID(t, receivedParams["client_assertion"]) != "test-kid" { 628 + t.Errorf("expected active signing key in client assertion") 430 629 } 431 - 432 - // Verify DPoP was forwarded 433 630 if receivedDPoP != "par-dpop-proof" { 434 631 t.Errorf("expected DPoP header to be forwarded, got %q", receivedDPoP) 435 632 } 633 + if resp.Header.Get(authProxyKeyIDHeader) != "test-kid" { 634 + t.Errorf("expected %s=test-kid, got %q", authProxyKeyIDHeader, resp.Header.Get(authProxyKeyIDHeader)) 635 + } 436 636 437 - // Verify response body was proxied 438 637 var parResp map[string]interface{} 439 638 if err := json.NewDecoder(resp.Body).Decode(&parResp); err != nil { 440 639 t.Fatalf("failed to decode response: %v", err) ··· 457 656 if resp.Header.Get("Access-Control-Allow-Origin") != "*" { 458 657 t.Errorf("expected CORS origin *, got %s", resp.Header.Get("Access-Control-Allow-Origin")) 459 658 } 460 - if resp.Header.Get("Access-Control-Expose-Headers") != "DPoP-Nonce" { 461 - t.Errorf("expected exposed DPoP-Nonce header, got %s", resp.Header.Get("Access-Control-Expose-Headers")) 659 + 660 + exposed := resp.Header.Get("Access-Control-Expose-Headers") 661 + if !strings.Contains(exposed, "DPoP-Nonce") { 662 + t.Errorf("expected exposed DPoP-Nonce header, got %s", exposed) 663 + } 664 + if !strings.Contains(exposed, authProxyKeyIDHeader) { 665 + t.Errorf("expected exposed %s header, got %s", authProxyKeyIDHeader, exposed) 462 666 } 463 667 } 464 668 ··· 466 670 srv, cleanup := setupTestServer(t) 467 671 defer cleanup() 468 672 469 - req, err := http.NewRequest("OPTIONS", srv.URL+"/oauth/token", nil) 673 + req, err := http.NewRequest(http.MethodOptions, srv.URL+"/oauth/token", nil) 470 674 if err != nil { 471 675 t.Fatalf("failed to create request: %v", err) 472 676 }
+1 -1
middleware.go
··· 7 7 w.Header().Set("Access-Control-Allow-Origin", allowedOrigins) 8 8 w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") 9 9 w.Header().Set("Access-Control-Allow-Headers", "Content-Type, DPoP") 10 - w.Header().Set("Access-Control-Expose-Headers", "DPoP-Nonce") 10 + w.Header().Set("Access-Control-Expose-Headers", "DPoP-Nonce, "+authProxyKeyIDHeader) 11 11 12 12 if r.Method == http.MethodOptions { 13 13 w.WriteHeader(http.StatusNoContent)
+32 -29
proxy.go
··· 1 1 package main 2 2 3 3 import ( 4 + "context" 4 5 "fmt" 5 6 "io" 6 7 "net/http" ··· 9 10 "time" 10 11 ) 11 12 12 - var upstreamClient = &http.Client{ 13 - Timeout: 30 * time.Second, 14 - CheckRedirect: func(req *http.Request, via []*http.Request) error { 15 - if len(via) >= 5 { 16 - return fmt.Errorf("too many redirects") 17 - } 18 - host := req.URL.Hostname() 19 - if isPrivateHost(host) { 20 - return fmt.Errorf("redirect to private address blocked") 21 - } 22 - if err := validateResolvedHost(host); err != nil { 23 - return fmt.Errorf("redirect target blocked: %w", err) 24 - } 25 - return nil 26 - }, 13 + type upstreamResponse struct { 14 + statusCode int 15 + headers http.Header 16 + body []byte 27 17 } 28 18 29 - // ProxyRequest forwards a form-encoded POST to the upstream URL. 30 - // Returns (true, err) if the response was already started when the error occurred, 31 - // or (false, err) if no response bytes were written yet. 32 - func ProxyRequest(w http.ResponseWriter, upstreamURL string, formParams url.Values, dpopHeader string) (bool, error) { 33 - req, err := http.NewRequest(http.MethodPost, upstreamURL, strings.NewReader(formParams.Encode())) 19 + var upstreamClient = newPublicHTTPClient(30 * time.Second) 20 + 21 + func PostForm(ctx context.Context, upstreamURL string, formParams url.Values, dpopHeader string) (*upstreamResponse, error) { 22 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, strings.NewReader(formParams.Encode())) 34 23 if err != nil { 35 - return false, fmt.Errorf("failed to create upstream request: %w", err) 24 + return nil, fmt.Errorf("failed to create upstream request: %w", err) 36 25 } 37 26 38 27 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 39 - 40 28 if dpopHeader != "" { 41 29 req.Header.Set("DPoP", dpopHeader) 42 30 } 43 31 44 32 resp, err := upstreamClient.Do(req) 45 33 if err != nil { 46 - return false, fmt.Errorf("upstream request failed: %w", err) 34 + return nil, fmt.Errorf("upstream request failed: %w", err) 47 35 } 48 36 defer resp.Body.Close() 49 37 50 - for key, values := range resp.Header { 38 + body, err := io.ReadAll(io.LimitReader(resp.Body, maxUpstreamResponseBodySize+1)) 39 + if err != nil { 40 + return nil, fmt.Errorf("failed to read upstream response body: %w", err) 41 + } 42 + if len(body) > maxUpstreamResponseBodySize { 43 + return nil, fmt.Errorf("upstream response body exceeded %d bytes", maxUpstreamResponseBodySize) 44 + } 45 + 46 + return &upstreamResponse{ 47 + statusCode: resp.StatusCode, 48 + headers: resp.Header.Clone(), 49 + body: body, 50 + }, nil 51 + } 52 + 53 + func WriteProxiedResponse(w http.ResponseWriter, resp *upstreamResponse) error { 54 + for key, values := range resp.headers { 51 55 for _, value := range values { 52 56 w.Header().Add(key, value) 53 57 } 54 58 } 55 59 56 - w.WriteHeader(resp.StatusCode) 57 - 58 - if _, err := io.Copy(w, resp.Body); err != nil { 59 - return true, fmt.Errorf("failed to copy response body: %w", err) 60 + w.WriteHeader(resp.statusCode) 61 + if _, err := w.Write(resp.body); err != nil { 62 + return fmt.Errorf("failed to write response body: %w", err) 60 63 } 61 64 62 - return true, nil 65 + return nil 63 66 }
+28 -11
ratelimit.go
··· 3 3 import ( 4 4 "net" 5 5 "net/http" 6 + "strings" 6 7 "sync" 7 8 "time" 8 9 ) ··· 71 72 return true 72 73 } 73 74 74 - func RateLimitMiddleware(rl *rateLimiter, next http.HandlerFunc) http.HandlerFunc { 75 + func RateLimitMiddleware(rl *rateLimiter, trustProxyHeaders bool, next http.HandlerFunc) http.HandlerFunc { 75 76 return func(w http.ResponseWriter, r *http.Request) { 76 - ip := clientIP(r) 77 + ip := clientIP(r, trustProxyHeaders) 77 78 if !rl.allow(ip) { 78 79 w.Header().Set("Retry-After", "60") 79 - http.Error(w, `{"error":"too_many_requests","error_description":"rate limit exceeded"}`, http.StatusTooManyRequests) 80 + writeJSONError(w, http.StatusTooManyRequests, "too_many_requests", "rate limit exceeded") 80 81 return 81 82 } 82 83 next(w, r) 83 84 } 84 85 } 85 86 86 - func clientIP(r *http.Request) string { 87 - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { 88 - // First entry is the original client 89 - for i := 0; i < len(xff); i++ { 90 - if xff[i] == ',' { 91 - return xff[:i] 92 - } 87 + func clientIP(r *http.Request, trustProxyHeaders bool) string { 88 + if trustProxyHeaders { 89 + if forwarded := forwardedIP(r.Header.Get("X-Forwarded-For")); forwarded != "" { 90 + return forwarded 93 91 } 94 - return xff 92 + if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); net.ParseIP(realIP) != nil { 93 + return realIP 94 + } 95 95 } 96 + 96 97 host, _, err := net.SplitHostPort(r.RemoteAddr) 97 98 if err != nil { 98 99 return r.RemoteAddr 99 100 } 100 101 return host 101 102 } 103 + 104 + func forwardedIP(headerValue string) string { 105 + if headerValue == "" { 106 + return "" 107 + } 108 + 109 + first := headerValue 110 + if idx := strings.IndexRune(headerValue, ','); idx >= 0 { 111 + first = headerValue[:idx] 112 + } 113 + first = strings.TrimSpace(first) 114 + if net.ParseIP(first) == nil { 115 + return "" 116 + } 117 + return first 118 + }
+25 -36
ratelimit_test.go
··· 19 19 t.Error("4th request from same IP should be blocked") 20 20 } 21 21 22 - // Different IP should still be allowed 23 22 if !rl.allow("5.6.7.8") { 24 23 t.Error("request from different IP should be allowed") 25 24 } ··· 34 33 } 35 34 } 36 35 37 - // Global limit hit, even from a new IP 38 36 if rl.allow("9.9.9.9") { 39 37 t.Error("request should be blocked by global limit") 40 38 } ··· 43 41 func TestRateLimiter_Combined(t *testing.T) { 44 42 rl := newRateLimiter(2, 5) 45 43 46 - // First IP gets 2 47 - if !rl.allow("1.1.1.1") { 48 - t.Fatal("should be allowed") 49 - } 50 - if !rl.allow("1.1.1.1") { 51 - t.Fatal("should be allowed") 44 + if !rl.allow("1.1.1.1") || !rl.allow("1.1.1.1") { 45 + t.Fatal("first IP should be allowed twice") 52 46 } 53 47 if rl.allow("1.1.1.1") { 54 48 t.Error("per-IP limit should block") 55 49 } 56 50 57 - // Second IP gets 2 more (global at 4 now, 2 counted + 2 more) 58 - // Wait, the per-IP block above didn't increment global. Let me reconsider. 59 - // Global count: 2 (from 1.1.1.1's allowed requests) 60 - if !rl.allow("2.2.2.2") { 61 - t.Fatal("should be allowed") 62 - } 63 - if !rl.allow("2.2.2.2") { 64 - t.Fatal("should be allowed") 51 + if !rl.allow("2.2.2.2") || !rl.allow("2.2.2.2") { 52 + t.Fatal("second IP should be allowed twice") 65 53 } 66 - // Global count: 4. Third IP gets 1 more before global limit. 67 54 if !rl.allow("3.3.3.3") { 68 - t.Fatal("should be allowed") 55 + t.Fatal("third IP should be allowed once") 69 56 } 70 - // Global count: 5. Now global limit hit. 71 57 if rl.allow("4.4.4.4") { 72 58 t.Error("global limit should block") 73 59 } ··· 87 73 srv, cleanup := setupTestServerWithRateLimit(t, 2, 0) 88 74 defer cleanup() 89 75 90 - body := `{"token_endpoint":"https://bsky.social/oauth/token","issuer":"https://bsky.social","grant_type":"authorization_code","code":"test"}` 91 - 92 - // First 2 should succeed (400 from validation, but not 429) 93 76 for i := 0; i < 2; i++ { 94 - resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 77 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(`not json`)) 95 78 if err != nil { 96 79 t.Fatalf("request %d failed: %v", i+1, err) 97 80 } ··· 101 84 } 102 85 } 103 86 104 - // 3rd should be rate limited 105 - resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 87 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(`not json`)) 106 88 if err != nil { 107 89 t.Fatalf("request failed: %v", err) 108 90 } ··· 111 93 if resp.StatusCode != http.StatusTooManyRequests { 112 94 t.Errorf("expected 429, got %d", resp.StatusCode) 113 95 } 114 - 115 96 if resp.Header.Get("Retry-After") != "60" { 116 97 t.Errorf("expected Retry-After: 60, got %s", resp.Header.Get("Retry-After")) 117 98 } ··· 119 100 120 101 func TestClientIP(t *testing.T) { 121 102 tests := []struct { 122 - name string 123 - remoteAddr string 124 - xff string 125 - expected string 103 + name string 104 + remoteAddr string 105 + xff string 106 + xRealIP string 107 + trustProxyHeaders bool 108 + expected string 126 109 }{ 127 - {"remote addr with port", "1.2.3.4:12345", "", "1.2.3.4"}, 128 - {"remote addr without port", "1.2.3.4", "", "1.2.3.4"}, 129 - {"x-forwarded-for single", "9.9.9.9:1234", "1.2.3.4", "1.2.3.4"}, 130 - {"x-forwarded-for multiple", "9.9.9.9:1234", "1.2.3.4, 5.6.7.8", "1.2.3.4"}, 110 + {name: "remote addr with port", remoteAddr: "1.2.3.4:12345", expected: "1.2.3.4"}, 111 + {name: "remote addr without port", remoteAddr: "1.2.3.4", expected: "1.2.3.4"}, 112 + {name: "xff ignored by default", remoteAddr: "9.9.9.9:1234", xff: "1.2.3.4", expected: "9.9.9.9"}, 113 + {name: "xff trusted when enabled", remoteAddr: "9.9.9.9:1234", xff: "1.2.3.4", trustProxyHeaders: true, expected: "1.2.3.4"}, 114 + {name: "xff multiple", remoteAddr: "9.9.9.9:1234", xff: "1.2.3.4, 5.6.7.8", trustProxyHeaders: true, expected: "1.2.3.4"}, 115 + {name: "x-real-ip fallback", remoteAddr: "9.9.9.9:1234", xRealIP: "1.2.3.4", trustProxyHeaders: true, expected: "1.2.3.4"}, 116 + {name: "invalid xff ignored", remoteAddr: "9.9.9.9:1234", xff: "not-an-ip", trustProxyHeaders: true, expected: "9.9.9.9"}, 131 117 } 132 118 133 119 for _, tt := range tests { ··· 139 125 if tt.xff != "" { 140 126 r.Header.Set("X-Forwarded-For", tt.xff) 141 127 } 142 - got := clientIP(r) 143 - if got != tt.expected { 128 + if tt.xRealIP != "" { 129 + r.Header.Set("X-Real-IP", tt.xRealIP) 130 + } 131 + 132 + if got := clientIP(r, tt.trustProxyHeaders); got != tt.expected { 144 133 t.Errorf("clientIP() = %q, want %q", got, tt.expected) 145 134 } 146 135 })
+79
signers.go
··· 1 + package main 2 + 3 + import ( 4 + "fmt" 5 + 6 + "github.com/lestrrat-go/jwx/v2/jwk" 7 + ) 8 + 9 + type SignerSet struct { 10 + activeKeyID string 11 + order []string 12 + signers map[string]jwk.Key 13 + } 14 + 15 + func NewSignerSet(keys []keyEntry, activeKeyID string) (*SignerSet, error) { 16 + if len(keys) == 0 { 17 + return nil, fmt.Errorf("at least one signing key is required") 18 + } 19 + 20 + signers := make(map[string]jwk.Key, len(keys)) 21 + order := make([]string, 0, len(keys)) 22 + 23 + for _, key := range keys { 24 + if key.kid == "" { 25 + return nil, fmt.Errorf("key ID is required for every signing key") 26 + } 27 + if _, exists := signers[key.kid]; exists { 28 + return nil, fmt.Errorf("duplicate signing key ID %q", key.kid) 29 + } 30 + 31 + signer, err := NewSigner(key.privateKey, key.kid) 32 + if err != nil { 33 + return nil, err 34 + } 35 + 36 + signers[key.kid] = signer 37 + order = append(order, key.kid) 38 + } 39 + 40 + if _, ok := signers[activeKeyID]; !ok { 41 + return nil, fmt.Errorf("active key ID %q is not configured", activeKeyID) 42 + } 43 + 44 + if order[0] != activeKeyID { 45 + reordered := []string{activeKeyID} 46 + for _, keyID := range order { 47 + if keyID != activeKeyID { 48 + reordered = append(reordered, keyID) 49 + } 50 + } 51 + order = reordered 52 + } 53 + 54 + return &SignerSet{ 55 + activeKeyID: activeKeyID, 56 + order: order, 57 + signers: signers, 58 + }, nil 59 + } 60 + 61 + func (s *SignerSet) Lookup(keyID string) (jwk.Key, error) { 62 + signer, ok := s.signers[keyID] 63 + if !ok { 64 + return nil, fmt.Errorf("unknown key_id %q", keyID) 65 + } 66 + 67 + return signer, nil 68 + } 69 + 70 + func (s *SignerSet) CandidateKeyIDs(requestedKeyID string) ([]string, error) { 71 + if requestedKeyID != "" { 72 + if _, ok := s.signers[requestedKeyID]; !ok { 73 + return nil, fmt.Errorf("unknown key_id %q", requestedKeyID) 74 + } 75 + return []string{requestedKeyID}, nil 76 + } 77 + 78 + return append([]string(nil), s.order...), nil 79 + }
+185 -58
validation.go
··· 1 1 package main 2 2 3 3 import ( 4 + "context" 4 5 "fmt" 5 6 "net" 7 + "net/http" 6 8 "net/url" 9 + "strconv" 10 + "strings" 7 11 "sync" 12 + "time" 8 13 ) 9 14 10 - var ( 11 - validatedHosts sync.Map 15 + var reservedIPv4Nets = []net.IPNet{ 16 + ipv4Net(0, 0, 0, 0, 8), 17 + ipv4Net(10, 0, 0, 0, 8), 18 + ipv4Net(100, 64, 0, 0, 10), 19 + ipv4Net(127, 0, 0, 0, 8), 20 + ipv4Net(169, 254, 0, 0, 16), 21 + ipv4Net(172, 16, 0, 0, 12), 22 + ipv4Net(192, 0, 0, 0, 24), 23 + ipv4Net(192, 0, 2, 0, 24), 24 + ipv4Net(192, 88, 99, 0, 24), 25 + ipv4Net(192, 168, 0, 0, 16), 26 + ipv4Net(198, 18, 0, 0, 15), 27 + ipv4Net(198, 51, 100, 0, 24), 28 + ipv4Net(203, 0, 113, 0, 24), 29 + ipv4Net(224, 0, 0, 0, 4), 30 + ipv4Net(240, 0, 0, 0, 4), 31 + } 32 + 33 + var globalUnicastIPv6Net = net.IPNet{ 34 + IP: net.IP{0x20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 35 + Mask: net.CIDRMask(3, 128), 36 + } 37 + 38 + var allowedHosts sync.Map 39 + 40 + func allowHost(host string) { 41 + allowedHosts.Store(strings.TrimSpace(strings.ToLower(host)), true) 42 + } 12 43 13 - privateRanges = []struct { 14 - network *net.IPNet 15 - }{ 16 - {mustParseCIDR("10.0.0.0/8")}, 17 - {mustParseCIDR("172.16.0.0/12")}, 18 - {mustParseCIDR("192.168.0.0/16")}, 19 - {mustParseCIDR("127.0.0.0/8")}, 20 - {mustParseCIDR("169.254.0.0/16")}, 21 - {mustParseCIDR("::1/128")}, 22 - {mustParseCIDR("fc00::/7")}, 23 - {mustParseCIDR("fe80::/10")}, 24 - } 25 - ) 44 + func clearAllowedHosts() { 45 + allowedHosts = sync.Map{} 46 + } 26 47 27 48 func ValidateTokenEndpoint(endpoint string) error { 28 - u, err := url.Parse(endpoint) 49 + _, err := validateEndpointURL(endpoint) 50 + return err 51 + } 52 + 53 + func ValidatePAREndpoint(endpoint string) error { 54 + _, err := validateEndpointURL(endpoint) 55 + return err 56 + } 57 + 58 + func ValidateIssuer(issuer string) (*url.URL, error) { 59 + return validatePublicURL(issuer, true) 60 + } 61 + 62 + func validateEndpointURL(rawURL string) (*url.URL, error) { 63 + return validatePublicURL(rawURL, false) 64 + } 65 + 66 + func validatePublicURL(rawURL string, requireOrigin bool) (*url.URL, error) { 67 + u, err := url.Parse(rawURL) 29 68 if err != nil { 30 - return fmt.Errorf("invalid URL: %w", err) 69 + return nil, fmt.Errorf("invalid URL: %w", err) 31 70 } 32 71 33 72 if u.Scheme != "https" { 34 - return fmt.Errorf("endpoint must use HTTPS") 73 + return nil, fmt.Errorf("URL must use HTTPS") 74 + } 75 + if u.User != nil { 76 + return nil, fmt.Errorf("URL must not include userinfo") 77 + } 78 + if u.Hostname() == "" { 79 + return nil, fmt.Errorf("URL must have a hostname") 80 + } 81 + if !isAllowedHost(u.Hostname()) && isBlockedHostname(u.Hostname()) { 82 + return nil, fmt.Errorf("URL must not target localhost") 83 + } 84 + if u.RawQuery != "" { 85 + return nil, fmt.Errorf("URL must not include a query string") 86 + } 87 + if u.Fragment != "" { 88 + return nil, fmt.Errorf("URL must not include a fragment") 35 89 } 36 90 37 - host := u.Hostname() 38 - if host == "" { 39 - return fmt.Errorf("endpoint must have a hostname") 91 + if requireOrigin { 92 + if u.Path != "" && u.Path != "/" { 93 + return nil, fmt.Errorf("issuer must be an origin URL without a path") 94 + } 95 + if u.Port() == "443" { 96 + return nil, fmt.Errorf("issuer must not include the default HTTPS port") 97 + } 40 98 } 41 99 42 - if _, ok := validatedHosts.Load(host); ok { 43 - return nil 100 + if ip := net.ParseIP(u.Hostname()); ip != nil && !IsPublicIPAddress(ip) && !isAllowedHost(u.Hostname()) { 101 + return nil, fmt.Errorf("URL must not target a private or reserved IP address") 44 102 } 45 103 46 - if isPrivateHost(host) { 47 - return fmt.Errorf("endpoint must not be a private/localhost address") 104 + return u, nil 105 + } 106 + 107 + func newPublicHTTPClient(timeout time.Duration) *http.Client { 108 + return &http.Client{ 109 + Timeout: timeout, 110 + Transport: newPublicOnlyTransport(), 111 + CheckRedirect: func(req *http.Request, via []*http.Request) error { 112 + if len(via) >= 5 { 113 + return fmt.Errorf("too many redirects") 114 + } 115 + _, err := validateEndpointURL(req.URL.String()) 116 + return err 117 + }, 48 118 } 119 + } 49 120 50 - if err := validateResolvedHost(host); err != nil { 51 - return err 121 + func newPublicOnlyTransport() *http.Transport { 122 + transport := http.DefaultTransport.(*http.Transport).Clone() 123 + dialer := &net.Dialer{ 124 + Timeout: 10 * time.Second, 125 + KeepAlive: 30 * time.Second, 52 126 } 53 127 54 - validatedHosts.Store(host, true) 55 - return nil 128 + transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { 129 + return dialPublicContext(ctx, dialer, network, address) 130 + } 131 + transport.ForceAttemptHTTP2 = true 132 + transport.MaxIdleConns = 100 133 + transport.IdleConnTimeout = 90 * time.Second 134 + transport.TLSHandshakeTimeout = 10 * time.Second 135 + transport.ResponseHeaderTimeout = 10 * time.Second 136 + transport.ExpectContinueTimeout = time.Second 137 + return transport 56 138 } 57 139 58 - // validateResolvedHost resolves a hostname via DNS and checks that none of 59 - // the resolved IPs are private. This prevents DNS rebinding attacks where 60 - // a public hostname resolves to a private IP. 61 - func validateResolvedHost(host string) error { 62 - // If it's already a literal IP, isPrivateHost handled it 63 - if net.ParseIP(host) != nil { 64 - return nil 140 + func dialPublicContext(ctx context.Context, dialer *net.Dialer, network string, address string) (net.Conn, error) { 141 + host, port, err := net.SplitHostPort(address) 142 + if err != nil { 143 + return nil, fmt.Errorf("invalid address %q: %w", address, err) 144 + } 145 + 146 + if !isAllowedHost(host) && isBlockedHostname(host) { 147 + return nil, fmt.Errorf("blocked host %q", host) 148 + } 149 + 150 + portNum, err := strconv.Atoi(port) 151 + if err != nil || portNum < 1 || portNum > 65535 { 152 + return nil, fmt.Errorf("invalid port %q", port) 153 + } 154 + 155 + if ip := net.ParseIP(host); ip != nil { 156 + if !IsPublicIPAddress(ip) && !isAllowedHost(host) { 157 + return nil, fmt.Errorf("blocked IP address %s", ip) 158 + } 159 + return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) 65 160 } 66 161 67 - ips, err := net.LookupHost(host) 162 + resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host) 68 163 if err != nil { 69 - return fmt.Errorf("failed to resolve host %q: %w", host, err) 164 + return nil, fmt.Errorf("failed to resolve host %q: %w", host, err) 165 + } 166 + if len(resolved) == 0 { 167 + return nil, fmt.Errorf("host %q did not resolve to any IPs", host) 168 + } 169 + 170 + addresses := make([]string, 0, len(resolved)) 171 + for _, addr := range resolved { 172 + if !IsPublicIPAddress(addr.IP) && !isAllowedHost(host) { 173 + return nil, fmt.Errorf("host %q resolves to non-public IP %s", host, addr.IP) 174 + } 175 + addresses = append(addresses, net.JoinHostPort(addr.IP.String(), port)) 70 176 } 71 177 72 - for _, ipStr := range ips { 73 - ip := net.ParseIP(ipStr) 74 - if ip == nil { 75 - continue 178 + var lastErr error 179 + for _, dialAddress := range addresses { 180 + conn, err := dialer.DialContext(ctx, network, dialAddress) 181 + if err == nil { 182 + return conn, nil 76 183 } 77 - if isPrivateIP(ip) { 78 - return fmt.Errorf("host %q resolves to private address %s", host, ipStr) 184 + lastErr = err 185 + } 186 + 187 + return nil, fmt.Errorf("failed to connect to %q: %w", host, lastErr) 188 + } 189 + 190 + func ipv4Net(a, b, c, d byte, subnetPrefixLen int) net.IPNet { 191 + return net.IPNet{ 192 + IP: net.IPv4(a, b, c, d), 193 + Mask: net.CIDRMask(96+subnetPrefixLen, 128), 194 + } 195 + } 196 + 197 + func IsPublicIPAddress(ip net.IP) bool { 198 + if ip4 := ip.To4(); ip4 != nil { 199 + for _, reserved := range reservedIPv4Nets { 200 + if reserved.Contains(ip4) { 201 + return false 202 + } 79 203 } 204 + return true 80 205 } 81 206 82 - return nil 207 + return globalUnicastIPv6Net.Contains(ip) 83 208 } 84 209 85 210 func isPrivateHost(host string) bool { 86 - if host == "localhost" { 211 + if isBlockedHostname(host) { 87 212 return true 88 213 } 89 214 ··· 92 217 return false 93 218 } 94 219 95 - return isPrivateIP(ip) 220 + return !IsPublicIPAddress(ip) 96 221 } 97 222 98 223 func isPrivateIP(ip net.IP) bool { 99 - for _, r := range privateRanges { 100 - if r.network.Contains(ip) { 101 - return true 102 - } 103 - } 104 - return false 224 + return !IsPublicIPAddress(ip) 105 225 } 106 226 107 - func mustParseCIDR(cidr string) *net.IPNet { 108 - _, network, err := net.ParseCIDR(cidr) 109 - if err != nil { 110 - panic(fmt.Sprintf("invalid CIDR: %s", cidr)) 227 + func isBlockedHostname(host string) bool { 228 + host = strings.TrimSpace(strings.ToLower(host)) 229 + return host == "localhost" || strings.HasSuffix(host, ".localhost") 230 + } 231 + 232 + func isAllowedHost(host string) bool { 233 + host = strings.TrimSpace(strings.ToLower(host)) 234 + if host == "" { 235 + return false 111 236 } 112 - return network 237 + 238 + _, ok := allowedHosts.Load(host) 239 + return ok 113 240 }
+62 -126
validation_test.go
··· 5 5 "testing" 6 6 ) 7 7 8 - func clearValidatedHosts() { 9 - validatedHosts.Range(func(key, value interface{}) bool { 10 - validatedHosts.Delete(key) 11 - return true 12 - }) 13 - } 14 - 15 8 func TestValidateTokenEndpoint(t *testing.T) { 16 - clearValidatedHosts() 17 - 18 - t.Run("valid HTTPS URL", func(t *testing.T) { 19 - if err := ValidateTokenEndpoint("https://bsky.social/oauth/token"); err != nil { 20 - t.Errorf("unexpected error: %v", err) 21 - } 22 - }) 23 - 24 - t.Run("HTTP rejected", func(t *testing.T) { 25 - if err := ValidateTokenEndpoint("http://bsky.social/oauth/token"); err == nil { 26 - t.Error("expected error for HTTP URL") 27 - } 28 - }) 29 - 30 - t.Run("empty URL", func(t *testing.T) { 31 - if err := ValidateTokenEndpoint(""); err == nil { 32 - t.Error("expected error for empty URL") 33 - } 34 - }) 35 - 36 - t.Run("localhost rejected", func(t *testing.T) { 37 - if err := ValidateTokenEndpoint("https://localhost/oauth/token"); err == nil { 38 - t.Error("expected error for localhost") 39 - } 40 - }) 41 - 42 - t.Run("127.0.0.1 rejected", func(t *testing.T) { 43 - if err := ValidateTokenEndpoint("https://127.0.0.1/oauth/token"); err == nil { 44 - t.Error("expected error for 127.0.0.1") 45 - } 46 - }) 47 - 48 - t.Run("10.x.x.x rejected", func(t *testing.T) { 49 - if err := ValidateTokenEndpoint("https://10.0.0.1/oauth/token"); err == nil { 50 - t.Error("expected error for 10.x.x.x") 51 - } 52 - }) 53 - 54 - t.Run("192.168.x.x rejected", func(t *testing.T) { 55 - if err := ValidateTokenEndpoint("https://192.168.1.1/oauth/token"); err == nil { 56 - t.Error("expected error for 192.168.x.x") 57 - } 58 - }) 59 - 60 - t.Run("172.16.x.x rejected", func(t *testing.T) { 61 - if err := ValidateTokenEndpoint("https://172.16.0.1/oauth/token"); err == nil { 62 - t.Error("expected error for 172.16.x.x") 63 - } 64 - }) 9 + clearAllowedHosts() 65 10 66 - t.Run("169.254.x.x rejected", func(t *testing.T) { 67 - if err := ValidateTokenEndpoint("https://169.254.169.254/oauth/token"); err == nil { 68 - t.Error("expected error for 169.254.x.x (link-local/cloud metadata)") 69 - } 70 - }) 71 - 72 - t.Run("IPv6 loopback rejected", func(t *testing.T) { 73 - if err := ValidateTokenEndpoint("https://[::1]/oauth/token"); err == nil { 74 - t.Error("expected error for ::1") 75 - } 76 - }) 77 - 78 - t.Run("IPv6 link-local rejected", func(t *testing.T) { 79 - if err := ValidateTokenEndpoint("https://[fe80::1]/oauth/token"); err == nil { 80 - t.Error("expected error for fe80::1") 81 - } 82 - }) 83 - 84 - t.Run("cached host succeeds", func(t *testing.T) { 85 - // Pre-populate the cache to simulate a previously validated host 86 - validatedHosts.Store("cached-test-host.example.com", true) 87 - host := "https://cached-test-host.example.com/oauth/token" 88 - if err := ValidateTokenEndpoint(host); err != nil { 89 - t.Errorf("cached call failed: %v", err) 90 - } 91 - }) 11 + tests := []struct { 12 + name string 13 + rawURL string 14 + wantErr bool 15 + }{ 16 + {name: "valid HTTPS URL", rawURL: "https://bsky.social/oauth/token"}, 17 + {name: "HTTP rejected", rawURL: "http://bsky.social/oauth/token", wantErr: true}, 18 + {name: "empty URL rejected", rawURL: "", wantErr: true}, 19 + {name: "localhost rejected", rawURL: "https://localhost/oauth/token", wantErr: true}, 20 + {name: "loopback rejected", rawURL: "https://127.0.0.1/oauth/token", wantErr: true}, 21 + {name: "private IPv4 rejected", rawURL: "https://10.0.0.1/oauth/token", wantErr: true}, 22 + {name: "cloud metadata rejected", rawURL: "https://169.254.169.254/oauth/token", wantErr: true}, 23 + {name: "IPv6 loopback rejected", rawURL: "https://[::1]/oauth/token", wantErr: true}, 24 + {name: "query rejected", rawURL: "https://bsky.social/oauth/token?x=1", wantErr: true}, 25 + {name: "fragment rejected", rawURL: "https://bsky.social/oauth/token#frag", wantErr: true}, 26 + {name: "no scheme rejected", rawURL: "bsky.social/oauth/token", wantErr: true}, 27 + } 92 28 93 - t.Run("no scheme rejected", func(t *testing.T) { 94 - if err := ValidateTokenEndpoint("bsky.social/oauth/token"); err == nil { 95 - t.Error("expected error for URL without scheme") 96 - } 97 - }) 29 + for _, tt := range tests { 30 + t.Run(tt.name, func(t *testing.T) { 31 + err := ValidateTokenEndpoint(tt.rawURL) 32 + if (err != nil) != tt.wantErr { 33 + t.Fatalf("ValidateTokenEndpoint(%q) error = %v, wantErr=%v", tt.rawURL, err, tt.wantErr) 34 + } 35 + }) 36 + } 98 37 } 99 38 100 - func TestIsPrivateHost(t *testing.T) { 39 + func TestValidateIssuer(t *testing.T) { 40 + clearAllowedHosts() 41 + 101 42 tests := []struct { 102 - host string 103 - private bool 43 + name string 44 + issuer string 45 + wantErr bool 104 46 }{ 105 - {"localhost", true}, 106 - {"127.0.0.1", true}, 107 - {"127.0.0.2", true}, 108 - {"10.0.0.1", true}, 109 - {"10.255.255.255", true}, 110 - {"172.16.0.1", true}, 111 - {"172.31.255.255", true}, 112 - {"192.168.0.1", true}, 113 - {"192.168.255.255", true}, 114 - {"169.254.169.254", true}, 115 - {"169.254.0.1", true}, 116 - {"::1", true}, 117 - {"fc00::1", true}, 118 - {"fe80::1", true}, 119 - {"8.8.8.8", false}, 120 - {"1.1.1.1", false}, 121 - {"bsky.social", false}, 122 - {"example.com", false}, 123 - {"172.32.0.1", false}, // just outside 172.16.0.0/12 47 + {name: "valid issuer", issuer: "https://bsky.social"}, 48 + {name: "valid issuer with trailing slash", issuer: "https://bsky.social/"}, 49 + {name: "issuer path rejected", issuer: "https://bsky.social/oauth", wantErr: true}, 50 + {name: "issuer query rejected", issuer: "https://bsky.social?x=1", wantErr: true}, 51 + {name: "issuer fragment rejected", issuer: "https://bsky.social#frag", wantErr: true}, 52 + {name: "issuer localhost rejected", issuer: "https://localhost", wantErr: true}, 53 + {name: "issuer default port rejected", issuer: "https://bsky.social:443", wantErr: true}, 124 54 } 125 55 126 56 for _, tt := range tests { 127 - t.Run(tt.host, func(t *testing.T) { 128 - result := isPrivateHost(tt.host) 129 - if result != tt.private { 130 - t.Errorf("isPrivateHost(%q) = %v, want %v", tt.host, result, tt.private) 57 + t.Run(tt.name, func(t *testing.T) { 58 + _, err := ValidateIssuer(tt.issuer) 59 + if (err != nil) != tt.wantErr { 60 + t.Fatalf("ValidateIssuer(%q) error = %v, wantErr=%v", tt.issuer, err, tt.wantErr) 131 61 } 132 62 }) 133 63 } 134 64 } 135 65 136 - func TestIsPrivateIP(t *testing.T) { 66 + func TestIsPublicIPAddress(t *testing.T) { 137 67 tests := []struct { 138 - ipStr string 139 - private bool 68 + ipStr string 69 + isPublic bool 140 70 }{ 141 - {"169.254.169.254", true}, 142 - {"169.254.0.1", true}, 143 - {"fe80::1", true}, 144 - {"fe80::abcd", true}, 145 - {"2001:db8::1", false}, 146 - {"8.8.4.4", false}, 71 + {ipStr: "8.8.8.8", isPublic: true}, 72 + {ipStr: "1.1.1.1", isPublic: true}, 73 + {ipStr: "10.0.0.1", isPublic: false}, 74 + {ipStr: "100.64.0.1", isPublic: false}, 75 + {ipStr: "127.0.0.1", isPublic: false}, 76 + {ipStr: "169.254.169.254", isPublic: false}, 77 + {ipStr: "192.0.2.10", isPublic: false}, 78 + {ipStr: "203.0.113.10", isPublic: false}, 79 + {ipStr: "::1", isPublic: false}, 80 + {ipStr: "fe80::1", isPublic: false}, 81 + {ipStr: "fc00::1", isPublic: false}, 82 + {ipStr: "2606:4700:4700::1111", isPublic: true}, 147 83 } 148 84 149 85 for _, tt := range tests { 150 86 t.Run(tt.ipStr, func(t *testing.T) { 151 87 ip := net.ParseIP(tt.ipStr) 152 88 if ip == nil { 153 - t.Fatalf("failed to parse IP: %s", tt.ipStr) 89 + t.Fatalf("failed to parse IP %q", tt.ipStr) 154 90 } 155 - result := isPrivateIP(ip) 156 - if result != tt.private { 157 - t.Errorf("isPrivateIP(%q) = %v, want %v", tt.ipStr, result, tt.private) 91 + 92 + if got := IsPublicIPAddress(ip); got != tt.isPublic { 93 + t.Fatalf("IsPublicIPAddress(%q) = %v, want %v", tt.ipStr, got, tt.isPublic) 158 94 } 159 95 }) 160 96 }