Stateless auth proxy that converts AT Protocol native apps from public to confidential OAuth clients. Deploy once, get 180-day refresh tokens instead of 24-hour ones.
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

add tests

+851 -4
+99
assertion_test.go
··· 1 + package main 2 + 3 + import ( 4 + "crypto/ecdsa" 5 + "crypto/elliptic" 6 + "crypto/rand" 7 + "testing" 8 + "time" 9 + 10 + "github.com/lestrrat-go/jwx/v2/jwa" 11 + "github.com/lestrrat-go/jwx/v2/jwk" 12 + "github.com/lestrrat-go/jwx/v2/jwt" 13 + ) 14 + 15 + func testSignerAndKey(t *testing.T) (jwk.Key, *ecdsa.PrivateKey) { 16 + t.Helper() 17 + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 18 + if err != nil { 19 + t.Fatalf("failed to generate key: %v", err) 20 + } 21 + signer, err := NewSigner(key, "test-kid") 22 + if err != nil { 23 + t.Fatalf("failed to create signer: %v", err) 24 + } 25 + return signer, key 26 + } 27 + 28 + func TestGenerateClientAssertion(t *testing.T) { 29 + signer, key := testSignerAndKey(t) 30 + clientID := "https://example.com/oauth/client-metadata.json" 31 + audience := "https://bsky.social/oauth/token" 32 + 33 + assertion, err := GenerateClientAssertion(signer, clientID, audience) 34 + if err != nil { 35 + t.Fatalf("unexpected error: %v", err) 36 + } 37 + 38 + if assertion == "" { 39 + t.Fatal("expected non-empty assertion") 40 + } 41 + 42 + // Parse and verify the JWT using the public key 43 + pubJWK, err := jwk.FromRaw(key.Public()) 44 + if err != nil { 45 + t.Fatalf("failed to create public JWK: %v", err) 46 + } 47 + 48 + parsed, err := jwt.Parse([]byte(assertion), jwt.WithKey(jwa.ES256, pubJWK)) 49 + if err != nil { 50 + t.Fatalf("failed to parse/verify JWT: %v", err) 51 + } 52 + 53 + if parsed.Issuer() != clientID { 54 + t.Errorf("expected iss=%s, got %s", clientID, parsed.Issuer()) 55 + } 56 + if parsed.Subject() != clientID { 57 + t.Errorf("expected sub=%s, got %s", clientID, parsed.Subject()) 58 + } 59 + 60 + audiences := parsed.Audience() 61 + if len(audiences) != 1 || audiences[0] != audience { 62 + t.Errorf("expected aud=[%s], got %v", audience, audiences) 63 + } 64 + 65 + if parsed.JwtID() == "" { 66 + t.Error("expected non-empty jti") 67 + } 68 + 69 + now := time.Now() 70 + if parsed.IssuedAt().After(now) { 71 + t.Error("iat should not be in the future") 72 + } 73 + 74 + expectedExp := parsed.IssuedAt().Add(60 * time.Second) 75 + diff := parsed.Expiration().Sub(expectedExp) 76 + if diff < -time.Second || diff > time.Second { 77 + t.Errorf("expected exp ~60s after iat, got iat=%v exp=%v", parsed.IssuedAt(), parsed.Expiration()) 78 + } 79 + } 80 + 81 + func TestGenerateClientAssertion_UniqueJTI(t *testing.T) { 82 + signer, _ := testSignerAndKey(t) 83 + clientID := "https://example.com/oauth/client-metadata.json" 84 + audience := "https://bsky.social/oauth/token" 85 + 86 + a1, err := GenerateClientAssertion(signer, clientID, audience) 87 + if err != nil { 88 + t.Fatalf("unexpected error: %v", err) 89 + } 90 + 91 + a2, err := GenerateClientAssertion(signer, clientID, audience) 92 + if err != nil { 93 + t.Fatalf("unexpected error: %v", err) 94 + } 95 + 96 + if a1 == a2 { 97 + t.Error("two assertions should have different jti values and therefore differ") 98 + } 99 + }
+173
keys_test.go
··· 1 + package main 2 + 3 + import ( 4 + "crypto/ecdsa" 5 + "crypto/elliptic" 6 + "crypto/rand" 7 + "crypto/x509" 8 + "encoding/json" 9 + "encoding/pem" 10 + "testing" 11 + ) 12 + 13 + func generateTestPEM(t *testing.T) string { 14 + t.Helper() 15 + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 16 + if err != nil { 17 + t.Fatalf("failed to generate test key: %v", err) 18 + } 19 + der, err := x509.MarshalPKCS8PrivateKey(key) 20 + if err != nil { 21 + t.Fatalf("failed to marshal test key: %v", err) 22 + } 23 + block := &pem.Block{Type: "PRIVATE KEY", Bytes: der} 24 + return string(pem.EncodeToMemory(block)) 25 + } 26 + 27 + func TestParsePrivateKey(t *testing.T) { 28 + t.Run("valid P-256 PEM", func(t *testing.T) { 29 + pemData := generateTestPEM(t) 30 + key, err := ParsePrivateKey(pemData) 31 + if err != nil { 32 + t.Fatalf("unexpected error: %v", err) 33 + } 34 + if key.Curve != elliptic.P256() { 35 + t.Fatalf("expected P-256 curve, got %s", key.Curve.Params().Name) 36 + } 37 + }) 38 + 39 + t.Run("invalid PEM", func(t *testing.T) { 40 + _, err := ParsePrivateKey("not a pem") 41 + if err == nil { 42 + t.Fatal("expected error for invalid PEM") 43 + } 44 + }) 45 + 46 + t.Run("empty PEM", func(t *testing.T) { 47 + _, err := ParsePrivateKey("") 48 + if err == nil { 49 + t.Fatal("expected error for empty PEM") 50 + } 51 + }) 52 + 53 + t.Run("wrong key type", func(t *testing.T) { 54 + // Generate a P-384 key (wrong curve) 55 + key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) 56 + if err != nil { 57 + t.Fatalf("failed to generate P-384 key: %v", err) 58 + } 59 + der, err := x509.MarshalPKCS8PrivateKey(key) 60 + if err != nil { 61 + t.Fatalf("failed to marshal key: %v", err) 62 + } 63 + block := &pem.Block{Type: "PRIVATE KEY", Bytes: der} 64 + pemData := string(pem.EncodeToMemory(block)) 65 + 66 + _, err = ParsePrivateKey(pemData) 67 + if err == nil { 68 + t.Fatal("expected error for non-P-256 key") 69 + } 70 + }) 71 + } 72 + 73 + func TestBuildJWKS(t *testing.T) { 74 + pemData := generateTestPEM(t) 75 + key, err := ParsePrivateKey(pemData) 76 + if err != nil { 77 + t.Fatalf("failed to parse key: %v", err) 78 + } 79 + 80 + kid := "test-key-1" 81 + jwksBytes, err := BuildJWKS(key, kid) 82 + if err != nil { 83 + t.Fatalf("unexpected error: %v", err) 84 + } 85 + 86 + var jwks struct { 87 + Keys []struct { 88 + Kty string `json:"kty"` 89 + Crv string `json:"crv"` 90 + Kid string `json:"kid"` 91 + Use string `json:"use"` 92 + Alg string `json:"alg"` 93 + X string `json:"x"` 94 + Y string `json:"y"` 95 + } `json:"keys"` 96 + } 97 + if err := json.Unmarshal(jwksBytes, &jwks); err != nil { 98 + t.Fatalf("failed to unmarshal JWKS: %v", err) 99 + } 100 + 101 + if len(jwks.Keys) != 1 { 102 + t.Fatalf("expected 1 key, got %d", len(jwks.Keys)) 103 + } 104 + 105 + k := jwks.Keys[0] 106 + if k.Kty != "EC" { 107 + t.Errorf("expected kty=EC, got %s", k.Kty) 108 + } 109 + if k.Crv != "P-256" { 110 + t.Errorf("expected crv=P-256, got %s", k.Crv) 111 + } 112 + if k.Kid != kid { 113 + t.Errorf("expected kid=%s, got %s", kid, k.Kid) 114 + } 115 + if k.Use != "sig" { 116 + t.Errorf("expected use=sig, got %s", k.Use) 117 + } 118 + if k.Alg != "ES256" { 119 + t.Errorf("expected alg=ES256, got %s", k.Alg) 120 + } 121 + if k.X == "" || k.Y == "" { 122 + t.Error("expected non-empty x and y coordinates") 123 + } 124 + } 125 + 126 + func TestBuildJWKS_NoPrivateKey(t *testing.T) { 127 + pemData := generateTestPEM(t) 128 + key, err := ParsePrivateKey(pemData) 129 + if err != nil { 130 + t.Fatalf("failed to parse key: %v", err) 131 + } 132 + 133 + jwksBytes, err := BuildJWKS(key, "test-key") 134 + if err != nil { 135 + t.Fatalf("unexpected error: %v", err) 136 + } 137 + 138 + // Verify no private key material (d) is present 139 + var raw map[string]json.RawMessage 140 + if err := json.Unmarshal(jwksBytes, &raw); err != nil { 141 + t.Fatalf("failed to unmarshal: %v", err) 142 + } 143 + 144 + var keys []map[string]json.RawMessage 145 + if err := json.Unmarshal(raw["keys"], &keys); err != nil { 146 + t.Fatalf("failed to unmarshal keys: %v", err) 147 + } 148 + 149 + if _, hasD := keys[0]["d"]; hasD { 150 + t.Fatal("JWKS must not contain private key material (d)") 151 + } 152 + } 153 + 154 + func TestNewSigner(t *testing.T) { 155 + pemData := generateTestPEM(t) 156 + key, err := ParsePrivateKey(pemData) 157 + if err != nil { 158 + t.Fatalf("failed to parse key: %v", err) 159 + } 160 + 161 + kid := "signer-key-1" 162 + signer, err := NewSigner(key, kid) 163 + if err != nil { 164 + t.Fatalf("unexpected error: %v", err) 165 + } 166 + 167 + if signer.KeyID() != kid { 168 + t.Errorf("expected kid=%s, got %s", kid, signer.KeyID()) 169 + } 170 + if signer.Algorithm().String() != "ES256" { 171 + t.Errorf("expected alg=ES256, got %s", signer.Algorithm()) 172 + } 173 + }
+455
main_test.go
··· 1 + package main 2 + 3 + import ( 4 + "encoding/json" 5 + "io" 6 + "net/http" 7 + "net/http/httptest" 8 + "net/url" 9 + "strings" 10 + "testing" 11 + ) 12 + 13 + func setupTestServer(t *testing.T) (*httptest.Server, func()) { 14 + t.Helper() 15 + 16 + pemData := generateTestPEM(t) 17 + key, err := ParsePrivateKey(pemData) 18 + if err != nil { 19 + t.Fatalf("failed to parse key: %v", err) 20 + } 21 + 22 + jwksJSON, err := BuildJWKS(key, "test-kid") 23 + if err != nil { 24 + t.Fatalf("failed to build JWKS: %v", err) 25 + } 26 + 27 + signingKey, err := NewSigner(key, "test-kid") 28 + if err != nil { 29 + t.Fatalf("failed to create signer: %v", err) 30 + } 31 + 32 + clientID := "https://example.com/oauth/client-metadata.json" 33 + 34 + mux := http.NewServeMux() 35 + mux.HandleFunc("GET /.well-known/jwks.json", HandleJWKS(jwksJSON)) 36 + mux.HandleFunc("POST /oauth/token", HandleToken(signingKey, clientID)) 37 + mux.HandleFunc("POST /oauth/par", HandlePAR(signingKey, clientID)) 38 + mux.HandleFunc("GET /health", HandleHealth) 39 + 40 + handler := CORSMiddleware("*", mux) 41 + srv := httptest.NewServer(handler) 42 + 43 + return srv, func() { srv.Close() } 44 + } 45 + 46 + func TestHealthEndpoint(t *testing.T) { 47 + srv, cleanup := setupTestServer(t) 48 + defer cleanup() 49 + 50 + resp, err := http.Get(srv.URL + "/health") 51 + if err != nil { 52 + t.Fatalf("request failed: %v", err) 53 + } 54 + defer resp.Body.Close() 55 + 56 + if resp.StatusCode != http.StatusOK { 57 + t.Errorf("expected 200, got %d", resp.StatusCode) 58 + } 59 + 60 + var body map[string]string 61 + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { 62 + t.Fatalf("failed to decode body: %v", err) 63 + } 64 + 65 + if body["status"] != "ok" { 66 + t.Errorf("expected status=ok, got %s", body["status"]) 67 + } 68 + } 69 + 70 + func TestJWKSEndpoint(t *testing.T) { 71 + srv, cleanup := setupTestServer(t) 72 + defer cleanup() 73 + 74 + resp, err := http.Get(srv.URL + "/.well-known/jwks.json") 75 + if err != nil { 76 + t.Fatalf("request failed: %v", err) 77 + } 78 + defer resp.Body.Close() 79 + 80 + if resp.StatusCode != http.StatusOK { 81 + t.Errorf("expected 200, got %d", resp.StatusCode) 82 + } 83 + 84 + if ct := resp.Header.Get("Content-Type"); ct != "application/json" { 85 + t.Errorf("expected Content-Type application/json, got %s", ct) 86 + } 87 + 88 + if cc := resp.Header.Get("Cache-Control"); cc != "public, max-age=3600" { 89 + t.Errorf("expected Cache-Control public, max-age=3600, got %s", cc) 90 + } 91 + 92 + var jwks struct { 93 + Keys []map[string]interface{} `json:"keys"` 94 + } 95 + if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { 96 + t.Fatalf("failed to decode JWKS: %v", err) 97 + } 98 + 99 + if len(jwks.Keys) != 1 { 100 + t.Fatalf("expected 1 key, got %d", len(jwks.Keys)) 101 + } 102 + 103 + key := jwks.Keys[0] 104 + if key["kty"] != "EC" { 105 + t.Errorf("expected kty=EC, got %v", key["kty"]) 106 + } 107 + if key["kid"] != "test-kid" { 108 + t.Errorf("expected kid=test-kid, got %v", key["kid"]) 109 + } 110 + } 111 + 112 + func TestTokenEndpoint_MissingFields(t *testing.T) { 113 + srv, cleanup := setupTestServer(t) 114 + defer cleanup() 115 + 116 + tests := []struct { 117 + name string 118 + body string 119 + }{ 120 + {"missing token_endpoint", `{"grant_type":"authorization_code"}`}, 121 + {"missing grant_type", `{"token_endpoint":"https://bsky.social/oauth/token"}`}, 122 + {"invalid JSON", `not json`}, 123 + } 124 + 125 + for _, tt := range tests { 126 + t.Run(tt.name, func(t *testing.T) { 127 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(tt.body)) 128 + if err != nil { 129 + t.Fatalf("request failed: %v", err) 130 + } 131 + defer resp.Body.Close() 132 + 133 + if resp.StatusCode != http.StatusBadRequest { 134 + t.Errorf("expected 400, got %d", resp.StatusCode) 135 + } 136 + }) 137 + } 138 + } 139 + 140 + func TestTokenEndpoint_InvalidEndpointURL(t *testing.T) { 141 + srv, cleanup := setupTestServer(t) 142 + defer cleanup() 143 + 144 + body := `{"token_endpoint":"http://bsky.social/oauth/token","grant_type":"authorization_code"}` 145 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 146 + if err != nil { 147 + t.Fatalf("request failed: %v", err) 148 + } 149 + defer resp.Body.Close() 150 + 151 + if resp.StatusCode != http.StatusBadRequest { 152 + t.Errorf("expected 400 for HTTP endpoint, got %d", resp.StatusCode) 153 + } 154 + } 155 + 156 + func allowTestHost(t *testing.T, serverURL string) { 157 + t.Helper() 158 + u, err := url.Parse(serverURL) 159 + if err != nil { 160 + t.Fatalf("failed to parse test server URL: %v", err) 161 + } 162 + validatedHosts.Store(u.Hostname(), true) 163 + } 164 + 165 + func TestTokenEndpoint_ProxiesWithAssertion(t *testing.T) { 166 + // Create a mock upstream auth server 167 + var receivedParams url.Values 168 + var receivedDPoP string 169 + upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 170 + if err := r.ParseForm(); err != nil { 171 + http.Error(w, "bad form", 400) 172 + return 173 + } 174 + receivedParams = r.Form 175 + receivedDPoP = r.Header.Get("DPoP") 176 + w.Header().Set("DPoP-Nonce", "test-nonce-123") 177 + w.Header().Set("Content-Type", "application/json") 178 + w.WriteHeader(http.StatusOK) 179 + w.Write([]byte(`{"access_token":"at_test","token_type":"DPoP","expires_in":300}`)) 180 + })) 181 + defer upstream.Close() 182 + allowTestHost(t, upstream.URL) 183 + 184 + // Use the TLS test server's client for proxying 185 + http.DefaultClient = upstream.Client() 186 + defer func() { http.DefaultClient = &http.Client{} }() 187 + 188 + srv, cleanup := setupTestServer(t) 189 + defer cleanup() 190 + 191 + body := `{ 192 + "token_endpoint":"` + upstream.URL + `/oauth/token", 193 + "grant_type":"authorization_code", 194 + "code":"test-auth-code", 195 + "redirect_uri":"myapp://callback", 196 + "code_verifier":"test-verifier" 197 + }` 198 + 199 + req, err := http.NewRequest("POST", srv.URL+"/oauth/token", strings.NewReader(body)) 200 + if err != nil { 201 + t.Fatalf("failed to create request: %v", err) 202 + } 203 + req.Header.Set("Content-Type", "application/json") 204 + req.Header.Set("DPoP", "test-dpop-proof") 205 + 206 + resp, err := http.DefaultClient.Do(req) 207 + if err != nil { 208 + t.Fatalf("request failed: %v", err) 209 + } 210 + defer resp.Body.Close() 211 + 212 + if resp.StatusCode != http.StatusOK { 213 + respBody, _ := io.ReadAll(resp.Body) 214 + t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(respBody)) 215 + } 216 + 217 + // Verify client_assertion was added 218 + if receivedParams.Get("client_assertion") == "" { 219 + t.Error("expected client_assertion to be added") 220 + } 221 + if receivedParams.Get("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 222 + t.Errorf("unexpected client_assertion_type: %s", receivedParams.Get("client_assertion_type")) 223 + } 224 + if receivedParams.Get("client_id") != "https://example.com/oauth/client-metadata.json" { 225 + t.Errorf("unexpected client_id: %s", receivedParams.Get("client_id")) 226 + } 227 + if receivedParams.Get("grant_type") != "authorization_code" { 228 + t.Errorf("unexpected grant_type: %s", receivedParams.Get("grant_type")) 229 + } 230 + if receivedParams.Get("code") != "test-auth-code" { 231 + t.Errorf("unexpected code: %s", receivedParams.Get("code")) 232 + } 233 + if receivedParams.Get("redirect_uri") != "myapp://callback" { 234 + t.Errorf("unexpected redirect_uri: %s", receivedParams.Get("redirect_uri")) 235 + } 236 + if receivedParams.Get("code_verifier") != "test-verifier" { 237 + t.Errorf("unexpected code_verifier: %s", receivedParams.Get("code_verifier")) 238 + } 239 + 240 + // Verify DPoP was forwarded 241 + if receivedDPoP != "test-dpop-proof" { 242 + t.Errorf("expected DPoP header to be forwarded, got %q", receivedDPoP) 243 + } 244 + 245 + // Verify DPoP-Nonce header was proxied back 246 + if resp.Header.Get("DPoP-Nonce") != "test-nonce-123" { 247 + t.Errorf("expected DPoP-Nonce header, got %q", resp.Header.Get("DPoP-Nonce")) 248 + } 249 + 250 + // Verify response body was proxied 251 + var tokenResp map[string]interface{} 252 + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { 253 + t.Fatalf("failed to decode response: %v", err) 254 + } 255 + if tokenResp["access_token"] != "at_test" { 256 + t.Errorf("unexpected access_token: %v", tokenResp["access_token"]) 257 + } 258 + } 259 + 260 + func TestTokenEndpoint_UpstreamErrorProxied(t *testing.T) { 261 + upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 262 + w.Header().Set("Content-Type", "application/json") 263 + w.WriteHeader(http.StatusBadRequest) 264 + w.Write([]byte(`{"error":"invalid_grant","error_description":"auth code expired"}`)) 265 + })) 266 + defer upstream.Close() 267 + allowTestHost(t, upstream.URL) 268 + 269 + http.DefaultClient = upstream.Client() 270 + defer func() { http.DefaultClient = &http.Client{} }() 271 + 272 + srv, cleanup := setupTestServer(t) 273 + defer cleanup() 274 + 275 + body := `{"token_endpoint":"` + upstream.URL + `/oauth/token","grant_type":"authorization_code","code":"expired-code"}` 276 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 277 + if err != nil { 278 + t.Fatalf("request failed: %v", err) 279 + } 280 + defer resp.Body.Close() 281 + 282 + if resp.StatusCode != http.StatusBadRequest { 283 + t.Errorf("expected upstream 400 to be proxied, got %d", resp.StatusCode) 284 + } 285 + 286 + var errResp map[string]string 287 + if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil { 288 + t.Fatalf("failed to decode error response: %v", err) 289 + } 290 + if errResp["error"] != "invalid_grant" { 291 + t.Errorf("expected error=invalid_grant, got %s", errResp["error"]) 292 + } 293 + } 294 + 295 + func TestPAREndpoint_MissingFields(t *testing.T) { 296 + srv, cleanup := setupTestServer(t) 297 + defer cleanup() 298 + 299 + tests := []struct { 300 + name string 301 + body string 302 + }{ 303 + {"missing par_endpoint", `{"scope":"atproto"}`}, 304 + {"invalid JSON", `{broken`}, 305 + } 306 + 307 + for _, tt := range tests { 308 + t.Run(tt.name, func(t *testing.T) { 309 + resp, err := http.Post(srv.URL+"/oauth/par", "application/json", strings.NewReader(tt.body)) 310 + if err != nil { 311 + t.Fatalf("request failed: %v", err) 312 + } 313 + defer resp.Body.Close() 314 + 315 + if resp.StatusCode != http.StatusBadRequest { 316 + t.Errorf("expected 400, got %d", resp.StatusCode) 317 + } 318 + }) 319 + } 320 + } 321 + 322 + func TestPAREndpoint_ProxiesWithAssertion(t *testing.T) { 323 + var receivedParams url.Values 324 + var receivedDPoP string 325 + upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 326 + if err := r.ParseForm(); err != nil { 327 + http.Error(w, "bad form", 400) 328 + return 329 + } 330 + receivedParams = r.Form 331 + receivedDPoP = r.Header.Get("DPoP") 332 + w.Header().Set("Content-Type", "application/json") 333 + w.WriteHeader(http.StatusCreated) 334 + w.Write([]byte(`{"request_uri":"urn:ietf:params:oauth:request_uri:abc123","expires_in":60}`)) 335 + })) 336 + defer upstream.Close() 337 + allowTestHost(t, upstream.URL) 338 + 339 + http.DefaultClient = upstream.Client() 340 + defer func() { http.DefaultClient = &http.Client{} }() 341 + 342 + srv, cleanup := setupTestServer(t) 343 + defer cleanup() 344 + 345 + body := `{ 346 + "par_endpoint":"` + upstream.URL + `/oauth/par", 347 + "login_hint":"user.bsky.social", 348 + "scope":"atproto transition:generic", 349 + "code_challenge":"test-challenge", 350 + "code_challenge_method":"S256", 351 + "state":"test-state", 352 + "redirect_uri":"myapp://callback" 353 + }` 354 + 355 + req, err := http.NewRequest("POST", srv.URL+"/oauth/par", strings.NewReader(body)) 356 + if err != nil { 357 + t.Fatalf("failed to create request: %v", err) 358 + } 359 + req.Header.Set("Content-Type", "application/json") 360 + req.Header.Set("DPoP", "par-dpop-proof") 361 + 362 + resp, err := http.DefaultClient.Do(req) 363 + if err != nil { 364 + t.Fatalf("request failed: %v", err) 365 + } 366 + defer resp.Body.Close() 367 + 368 + if resp.StatusCode != http.StatusCreated { 369 + respBody, _ := io.ReadAll(resp.Body) 370 + t.Fatalf("expected 201, got %d: %s", resp.StatusCode, string(respBody)) 371 + } 372 + 373 + // Verify client_assertion was added 374 + if receivedParams.Get("client_assertion") == "" { 375 + t.Error("expected client_assertion to be added") 376 + } 377 + if receivedParams.Get("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 378 + t.Errorf("unexpected client_assertion_type: %s", receivedParams.Get("client_assertion_type")) 379 + } 380 + if receivedParams.Get("response_type") != "code" { 381 + t.Errorf("expected response_type=code, got %s", receivedParams.Get("response_type")) 382 + } 383 + if receivedParams.Get("scope") != "atproto transition:generic" { 384 + t.Errorf("unexpected scope: %s", receivedParams.Get("scope")) 385 + } 386 + if receivedParams.Get("login_hint") != "user.bsky.social" { 387 + t.Errorf("unexpected login_hint: %s", receivedParams.Get("login_hint")) 388 + } 389 + if receivedParams.Get("code_challenge") != "test-challenge" { 390 + t.Errorf("unexpected code_challenge: %s", receivedParams.Get("code_challenge")) 391 + } 392 + if receivedParams.Get("state") != "test-state" { 393 + t.Errorf("unexpected state: %s", receivedParams.Get("state")) 394 + } 395 + if receivedParams.Get("redirect_uri") != "myapp://callback" { 396 + t.Errorf("unexpected redirect_uri: %s", receivedParams.Get("redirect_uri")) 397 + } 398 + 399 + // Verify DPoP was forwarded 400 + if receivedDPoP != "par-dpop-proof" { 401 + t.Errorf("expected DPoP header to be forwarded, got %q", receivedDPoP) 402 + } 403 + 404 + // Verify response body was proxied 405 + var parResp map[string]interface{} 406 + if err := json.NewDecoder(resp.Body).Decode(&parResp); err != nil { 407 + t.Fatalf("failed to decode response: %v", err) 408 + } 409 + if parResp["request_uri"] != "urn:ietf:params:oauth:request_uri:abc123" { 410 + t.Errorf("unexpected request_uri: %v", parResp["request_uri"]) 411 + } 412 + } 413 + 414 + func TestCORSHeaders(t *testing.T) { 415 + srv, cleanup := setupTestServer(t) 416 + defer cleanup() 417 + 418 + resp, err := http.Get(srv.URL + "/health") 419 + if err != nil { 420 + t.Fatalf("request failed: %v", err) 421 + } 422 + defer resp.Body.Close() 423 + 424 + if resp.Header.Get("Access-Control-Allow-Origin") != "*" { 425 + t.Errorf("expected CORS origin *, got %s", resp.Header.Get("Access-Control-Allow-Origin")) 426 + } 427 + if resp.Header.Get("Access-Control-Expose-Headers") != "DPoP-Nonce" { 428 + t.Errorf("expected exposed DPoP-Nonce header, got %s", resp.Header.Get("Access-Control-Expose-Headers")) 429 + } 430 + } 431 + 432 + func TestCORSPreflight(t *testing.T) { 433 + srv, cleanup := setupTestServer(t) 434 + defer cleanup() 435 + 436 + req, err := http.NewRequest("OPTIONS", srv.URL+"/oauth/token", nil) 437 + if err != nil { 438 + t.Fatalf("failed to create request: %v", err) 439 + } 440 + 441 + resp, err := http.DefaultClient.Do(req) 442 + if err != nil { 443 + t.Fatalf("request failed: %v", err) 444 + } 445 + defer resp.Body.Close() 446 + 447 + if resp.StatusCode != http.StatusNoContent { 448 + t.Errorf("expected 204 for OPTIONS, got %d", resp.StatusCode) 449 + } 450 + 451 + allowHeaders := resp.Header.Get("Access-Control-Allow-Headers") 452 + if !strings.Contains(allowHeaders, "DPoP") { 453 + t.Errorf("expected DPoP in allowed headers, got %s", allowHeaders) 454 + } 455 + }
+4 -4
validation.go
··· 26 26 return fmt.Errorf("endpoint must have a hostname") 27 27 } 28 28 29 - if isPrivateHost(host) { 30 - return fmt.Errorf("endpoint must not be a private/localhost address") 31 - } 32 - 33 29 if _, ok := validatedHosts.Load(host); ok { 34 30 return nil 31 + } 32 + 33 + if isPrivateHost(host) { 34 + return fmt.Errorf("endpoint must not be a private/localhost address") 35 35 } 36 36 37 37 validatedHosts.Store(host, true)
+120
validation_test.go
··· 1 + package main 2 + 3 + import ( 4 + "testing" 5 + ) 6 + 7 + func clearValidatedHosts() { 8 + validatedHosts.Range(func(key, value interface{}) bool { 9 + validatedHosts.Delete(key) 10 + return true 11 + }) 12 + } 13 + 14 + func TestValidateTokenEndpoint(t *testing.T) { 15 + clearValidatedHosts() 16 + 17 + t.Run("valid HTTPS URL", func(t *testing.T) { 18 + if err := ValidateTokenEndpoint("https://bsky.social/oauth/token"); err != nil { 19 + t.Errorf("unexpected error: %v", err) 20 + } 21 + }) 22 + 23 + t.Run("HTTP rejected", func(t *testing.T) { 24 + if err := ValidateTokenEndpoint("http://bsky.social/oauth/token"); err == nil { 25 + t.Error("expected error for HTTP URL") 26 + } 27 + }) 28 + 29 + t.Run("empty URL", func(t *testing.T) { 30 + if err := ValidateTokenEndpoint(""); err == nil { 31 + t.Error("expected error for empty URL") 32 + } 33 + }) 34 + 35 + t.Run("localhost rejected", func(t *testing.T) { 36 + if err := ValidateTokenEndpoint("https://localhost/oauth/token"); err == nil { 37 + t.Error("expected error for localhost") 38 + } 39 + }) 40 + 41 + t.Run("127.0.0.1 rejected", func(t *testing.T) { 42 + if err := ValidateTokenEndpoint("https://127.0.0.1/oauth/token"); err == nil { 43 + t.Error("expected error for 127.0.0.1") 44 + } 45 + }) 46 + 47 + t.Run("10.x.x.x rejected", func(t *testing.T) { 48 + if err := ValidateTokenEndpoint("https://10.0.0.1/oauth/token"); err == nil { 49 + t.Error("expected error for 10.x.x.x") 50 + } 51 + }) 52 + 53 + t.Run("192.168.x.x rejected", func(t *testing.T) { 54 + if err := ValidateTokenEndpoint("https://192.168.1.1/oauth/token"); err == nil { 55 + t.Error("expected error for 192.168.x.x") 56 + } 57 + }) 58 + 59 + t.Run("172.16.x.x rejected", func(t *testing.T) { 60 + if err := ValidateTokenEndpoint("https://172.16.0.1/oauth/token"); err == nil { 61 + t.Error("expected error for 172.16.x.x") 62 + } 63 + }) 64 + 65 + t.Run("IPv6 loopback rejected", func(t *testing.T) { 66 + if err := ValidateTokenEndpoint("https://[::1]/oauth/token"); err == nil { 67 + t.Error("expected error for ::1") 68 + } 69 + }) 70 + 71 + t.Run("cached host succeeds", func(t *testing.T) { 72 + host := "https://cached-test-host.example.com/oauth/token" 73 + if err := ValidateTokenEndpoint(host); err != nil { 74 + t.Fatalf("first call failed: %v", err) 75 + } 76 + // Second call should hit cache and succeed 77 + if err := ValidateTokenEndpoint(host); err != nil { 78 + t.Errorf("cached call failed: %v", err) 79 + } 80 + }) 81 + 82 + t.Run("no scheme rejected", func(t *testing.T) { 83 + if err := ValidateTokenEndpoint("bsky.social/oauth/token"); err == nil { 84 + t.Error("expected error for URL without scheme") 85 + } 86 + }) 87 + } 88 + 89 + func TestIsPrivateHost(t *testing.T) { 90 + tests := []struct { 91 + host string 92 + private bool 93 + }{ 94 + {"localhost", true}, 95 + {"127.0.0.1", true}, 96 + {"127.0.0.2", true}, 97 + {"10.0.0.1", true}, 98 + {"10.255.255.255", true}, 99 + {"172.16.0.1", true}, 100 + {"172.31.255.255", true}, 101 + {"192.168.0.1", true}, 102 + {"192.168.255.255", true}, 103 + {"::1", true}, 104 + {"fc00::1", true}, 105 + {"8.8.8.8", false}, 106 + {"1.1.1.1", false}, 107 + {"bsky.social", false}, 108 + {"example.com", false}, 109 + {"172.32.0.1", false}, // just outside 172.16.0.0/12 110 + } 111 + 112 + for _, tt := range tests { 113 + t.Run(tt.host, func(t *testing.T) { 114 + result := isPrivateHost(tt.host) 115 + if result != tt.private { 116 + t.Errorf("isPrivateHost(%q) = %v, want %v", tt.host, result, tt.private) 117 + } 118 + }) 119 + } 120 + }