Stateless auth proxy that converts AT Protocol native apps from public to confidential OAuth clients. Deploy once, get 180-day refresh tokens instead of 24-hour ones.
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

Fix aud claim to use issuer per spec, harden SSRF validation with DNS resolution, and return 502 on upstream failures

+186 -39
+3 -1
README.md
··· 124 124 125 125 ## Security Considerations 126 126 127 - - **Token endpoint validation**: The proxy validates that upstream URLs use HTTPS and rejects private/localhost addresses to prevent SSRF 127 + - **Token endpoint validation**: The proxy validates that upstream URLs use HTTPS, resolves hostnames via DNS to reject private addresses, and rejects private/localhost/link-local addresses to prevent SSRF 128 + - **Redirect protection**: Upstream HTTP redirects are validated to prevent redirection to private addresses 129 + - **Request timeout**: Upstream requests have a 30-second timeout to prevent slow-loris attacks 128 130 - **No token logging**: Token values, auth codes, and refresh tokens are never logged 129 131 - **HTTPS required**: The proxy must be served over HTTPS in production (handled automatically by Railway/Fly.io) 130 132 - **DPoP passthrough**: The proxy never sees DPoP private keys — proofs are between the device and auth server
+2 -2
assertion_test.go
··· 28 28 func TestGenerateClientAssertion(t *testing.T) { 29 29 signer, key := testSignerAndKey(t) 30 30 clientID := "https://example.com/oauth/client-metadata.json" 31 - audience := "https://bsky.social/oauth/token" 31 + audience := "https://bsky.social" 32 32 33 33 assertion, err := GenerateClientAssertion(signer, clientID, audience) 34 34 if err != nil { ··· 81 81 func TestGenerateClientAssertion_UniqueJTI(t *testing.T) { 82 82 signer, _ := testSignerAndKey(t) 83 83 clientID := "https://example.com/oauth/client-metadata.json" 84 - audience := "https://bsky.social/oauth/token" 84 + audience := "https://bsky.social" 85 85 86 86 a1, err := GenerateClientAssertion(signer, clientID, audience) 87 87 if err != nil {
+12 -2
handler_par.go
··· 11 11 12 12 type parRequest struct { 13 13 PAREndpoint string `json:"par_endpoint"` 14 + Issuer string `json:"issuer"` 14 15 LoginHint string `json:"login_hint,omitempty"` 15 16 Scope string `json:"scope"` 16 17 CodeChallenge string `json:"code_challenge"` ··· 32 33 return 33 34 } 34 35 36 + if req.Issuer == "" { 37 + http.Error(w, `{"error":"invalid_request","error_description":"issuer is required"}`, http.StatusBadRequest) 38 + return 39 + } 40 + 35 41 if err := ValidateTokenEndpoint(req.PAREndpoint); err != nil { 36 42 http.Error(w, `{"error":"invalid_request","error_description":"invalid par_endpoint"}`, http.StatusBadRequest) 37 43 return 38 44 } 39 45 40 - assertion, err := GenerateClientAssertion(signingKey, clientID, req.PAREndpoint) 46 + assertion, err := GenerateClientAssertion(signingKey, clientID, req.Issuer) 41 47 if err != nil { 42 48 log.Printf("failed to generate client assertion: %v", err) 43 49 http.Error(w, `{"error":"server_error","error_description":"failed to generate client assertion"}`, http.StatusInternalServerError) ··· 61 67 62 68 dpopHeader := r.Header.Get("DPoP") 63 69 64 - if err := ProxyRequest(w, req.PAREndpoint, params, dpopHeader); err != nil { 70 + started, err := ProxyRequest(w, req.PAREndpoint, params, dpopHeader) 71 + if err != nil { 65 72 log.Printf("proxy request failed: %v", err) 73 + if !started { 74 + http.Error(w, `{"error":"server_error","error_description":"upstream request failed"}`, http.StatusBadGateway) 75 + } 66 76 } 67 77 } 68 78 }
+12 -2
handler_token.go
··· 11 11 12 12 type tokenRequest struct { 13 13 TokenEndpoint string `json:"token_endpoint"` 14 + Issuer string `json:"issuer"` 14 15 GrantType string `json:"grant_type"` 15 16 Code string `json:"code,omitempty"` 16 17 RedirectURI string `json:"redirect_uri,omitempty"` ··· 31 32 return 32 33 } 33 34 35 + if req.Issuer == "" { 36 + http.Error(w, `{"error":"invalid_request","error_description":"issuer is required"}`, http.StatusBadRequest) 37 + return 38 + } 39 + 34 40 if req.GrantType == "" { 35 41 http.Error(w, `{"error":"invalid_request","error_description":"grant_type is required"}`, http.StatusBadRequest) 36 42 return ··· 41 47 return 42 48 } 43 49 44 - assertion, err := GenerateClientAssertion(signingKey, clientID, req.TokenEndpoint) 50 + assertion, err := GenerateClientAssertion(signingKey, clientID, req.Issuer) 45 51 if err != nil { 46 52 log.Printf("failed to generate client assertion: %v", err) 47 53 http.Error(w, `{"error":"server_error","error_description":"failed to generate client assertion"}`, http.StatusInternalServerError) ··· 69 75 70 76 dpopHeader := r.Header.Get("DPoP") 71 77 72 - if err := ProxyRequest(w, req.TokenEndpoint, params, dpopHeader); err != nil { 78 + started, err := ProxyRequest(w, req.TokenEndpoint, params, dpopHeader) 79 + if err != nil { 73 80 log.Printf("proxy request failed: %v", err) 81 + if !started { 82 + http.Error(w, `{"error":"server_error","error_description":"upstream request failed"}`, http.StatusBadGateway) 83 + } 74 84 } 75 85 } 76 86 }
+38 -11
main_test.go
··· 117 117 name string 118 118 body string 119 119 }{ 120 - {"missing token_endpoint", `{"grant_type":"authorization_code"}`}, 121 - {"missing grant_type", `{"token_endpoint":"https://bsky.social/oauth/token"}`}, 120 + {"missing token_endpoint", `{"issuer":"https://bsky.social","grant_type":"authorization_code"}`}, 121 + {"missing issuer", `{"token_endpoint":"https://bsky.social/oauth/token","grant_type":"authorization_code"}`}, 122 + {"missing grant_type", `{"token_endpoint":"https://bsky.social/oauth/token","issuer":"https://bsky.social"}`}, 122 123 {"invalid JSON", `not json`}, 123 124 } 124 125 ··· 141 142 srv, cleanup := setupTestServer(t) 142 143 defer cleanup() 143 144 144 - body := `{"token_endpoint":"http://bsky.social/oauth/token","grant_type":"authorization_code"}` 145 + body := `{"token_endpoint":"http://bsky.social/oauth/token","issuer":"https://bsky.social","grant_type":"authorization_code"}` 145 146 resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 146 147 if err != nil { 147 148 t.Fatalf("request failed: %v", err) ··· 182 183 allowTestHost(t, upstream.URL) 183 184 184 185 // Use the TLS test server's client for proxying 185 - http.DefaultClient = upstream.Client() 186 - defer func() { http.DefaultClient = &http.Client{} }() 186 + oldClient := upstreamClient 187 + upstreamClient = upstream.Client() 188 + defer func() { upstreamClient = oldClient }() 187 189 188 190 srv, cleanup := setupTestServer(t) 189 191 defer cleanup() 190 192 191 193 body := `{ 192 194 "token_endpoint":"` + upstream.URL + `/oauth/token", 195 + "issuer":"https://bsky.social", 193 196 "grant_type":"authorization_code", 194 197 "code":"test-auth-code", 195 198 "redirect_uri":"myapp://callback", ··· 266 269 defer upstream.Close() 267 270 allowTestHost(t, upstream.URL) 268 271 269 - http.DefaultClient = upstream.Client() 270 - defer func() { http.DefaultClient = &http.Client{} }() 272 + oldClient := upstreamClient 273 + upstreamClient = upstream.Client() 274 + defer func() { upstreamClient = oldClient }() 271 275 272 276 srv, cleanup := setupTestServer(t) 273 277 defer cleanup() 274 278 275 - body := `{"token_endpoint":"` + upstream.URL + `/oauth/token","grant_type":"authorization_code","code":"expired-code"}` 279 + body := `{"token_endpoint":"` + upstream.URL + `/oauth/token","issuer":"https://bsky.social","grant_type":"authorization_code","code":"expired-code"}` 276 280 resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 277 281 if err != nil { 278 282 t.Fatalf("request failed: %v", err) ··· 292 296 } 293 297 } 294 298 299 + func TestTokenEndpoint_UpstreamUnreachable(t *testing.T) { 300 + srv, cleanup := setupTestServer(t) 301 + defer cleanup() 302 + 303 + // Point to a host that won't connect — use a pre-allowed host with a bad port 304 + allowTestHost(t, "https://unreachable-test-host.invalid") 305 + 306 + body := `{"token_endpoint":"https://unreachable-test-host.invalid:19999/oauth/token","issuer":"https://unreachable-test-host.invalid","grant_type":"authorization_code","code":"test"}` 307 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 308 + if err != nil { 309 + t.Fatalf("request failed: %v", err) 310 + } 311 + defer resp.Body.Close() 312 + 313 + if resp.StatusCode != http.StatusBadGateway { 314 + respBody, _ := io.ReadAll(resp.Body) 315 + t.Errorf("expected 502 for unreachable upstream, got %d: %s", resp.StatusCode, string(respBody)) 316 + } 317 + } 318 + 295 319 func TestPAREndpoint_MissingFields(t *testing.T) { 296 320 srv, cleanup := setupTestServer(t) 297 321 defer cleanup() ··· 300 324 name string 301 325 body string 302 326 }{ 303 - {"missing par_endpoint", `{"scope":"atproto"}`}, 327 + {"missing par_endpoint", `{"issuer":"https://bsky.social","scope":"atproto"}`}, 328 + {"missing issuer", `{"par_endpoint":"https://bsky.social/oauth/par","scope":"atproto"}`}, 304 329 {"invalid JSON", `{broken`}, 305 330 } 306 331 ··· 336 361 defer upstream.Close() 337 362 allowTestHost(t, upstream.URL) 338 363 339 - http.DefaultClient = upstream.Client() 340 - defer func() { http.DefaultClient = &http.Client{} }() 364 + oldClient := upstreamClient 365 + upstreamClient = upstream.Client() 366 + defer func() { upstreamClient = oldClient }() 341 367 342 368 srv, cleanup := setupTestServer(t) 343 369 defer cleanup() 344 370 345 371 body := `{ 346 372 "par_endpoint":"` + upstream.URL + `/oauth/par", 373 + "issuer":"https://bsky.social", 347 374 "login_hint":"user.bsky.social", 348 375 "scope":"atproto transition:generic", 349 376 "code_challenge":"test-challenge",
+27 -6
proxy.go
··· 6 6 "net/http" 7 7 "net/url" 8 8 "strings" 9 + "time" 9 10 ) 10 11 11 - func ProxyRequest(w http.ResponseWriter, upstreamURL string, formParams url.Values, dpopHeader string) error { 12 + var upstreamClient = &http.Client{ 13 + Timeout: 30 * time.Second, 14 + CheckRedirect: func(req *http.Request, via []*http.Request) error { 15 + if len(via) >= 5 { 16 + return fmt.Errorf("too many redirects") 17 + } 18 + host := req.URL.Hostname() 19 + if isPrivateHost(host) { 20 + return fmt.Errorf("redirect to private address blocked") 21 + } 22 + if err := validateResolvedHost(host); err != nil { 23 + return fmt.Errorf("redirect target blocked: %w", err) 24 + } 25 + return nil 26 + }, 27 + } 28 + 29 + // ProxyRequest forwards a form-encoded POST to the upstream URL. 30 + // Returns (true, err) if the response was already started when the error occurred, 31 + // or (false, err) if no response bytes were written yet. 32 + func ProxyRequest(w http.ResponseWriter, upstreamURL string, formParams url.Values, dpopHeader string) (bool, error) { 12 33 req, err := http.NewRequest(http.MethodPost, upstreamURL, strings.NewReader(formParams.Encode())) 13 34 if err != nil { 14 - return fmt.Errorf("failed to create upstream request: %w", err) 35 + return false, fmt.Errorf("failed to create upstream request: %w", err) 15 36 } 16 37 17 38 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") ··· 20 41 req.Header.Set("DPoP", dpopHeader) 21 42 } 22 43 23 - resp, err := http.DefaultClient.Do(req) 44 + resp, err := upstreamClient.Do(req) 24 45 if err != nil { 25 - return fmt.Errorf("upstream request failed: %w", err) 46 + return false, fmt.Errorf("upstream request failed: %w", err) 26 47 } 27 48 defer resp.Body.Close() 28 49 ··· 35 56 w.WriteHeader(resp.StatusCode) 36 57 37 58 if _, err := io.Copy(w, resp.Body); err != nil { 38 - return fmt.Errorf("failed to copy response body: %w", err) 59 + return true, fmt.Errorf("failed to copy response body: %w", err) 39 60 } 40 61 41 - return nil 62 + return true, nil 42 63 }
+47 -11
validation.go
··· 9 9 10 10 var ( 11 11 validatedHosts sync.Map 12 + 13 + privateRanges = []struct { 14 + network *net.IPNet 15 + }{ 16 + {mustParseCIDR("10.0.0.0/8")}, 17 + {mustParseCIDR("172.16.0.0/12")}, 18 + {mustParseCIDR("192.168.0.0/16")}, 19 + {mustParseCIDR("127.0.0.0/8")}, 20 + {mustParseCIDR("169.254.0.0/16")}, 21 + {mustParseCIDR("::1/128")}, 22 + {mustParseCIDR("fc00::/7")}, 23 + {mustParseCIDR("fe80::/10")}, 24 + } 12 25 ) 13 26 14 27 func ValidateTokenEndpoint(endpoint string) error { ··· 34 47 return fmt.Errorf("endpoint must not be a private/localhost address") 35 48 } 36 49 50 + if err := validateResolvedHost(host); err != nil { 51 + return err 52 + } 53 + 37 54 validatedHosts.Store(host, true) 38 55 return nil 39 56 } 40 57 58 + // validateResolvedHost resolves a hostname via DNS and checks that none of 59 + // the resolved IPs are private. This prevents DNS rebinding attacks where 60 + // a public hostname resolves to a private IP. 61 + func validateResolvedHost(host string) error { 62 + // If it's already a literal IP, isPrivateHost handled it 63 + if net.ParseIP(host) != nil { 64 + return nil 65 + } 66 + 67 + ips, err := net.LookupHost(host) 68 + if err != nil { 69 + return fmt.Errorf("failed to resolve host %q: %w", host, err) 70 + } 71 + 72 + for _, ipStr := range ips { 73 + ip := net.ParseIP(ipStr) 74 + if ip == nil { 75 + continue 76 + } 77 + if isPrivateIP(ip) { 78 + return fmt.Errorf("host %q resolves to private address %s", host, ipStr) 79 + } 80 + } 81 + 82 + return nil 83 + } 84 + 41 85 func isPrivateHost(host string) bool { 42 86 if host == "localhost" { 43 87 return true ··· 48 92 return false 49 93 } 50 94 51 - privateRanges := []struct { 52 - network *net.IPNet 53 - }{ 54 - {mustParseCIDR("10.0.0.0/8")}, 55 - {mustParseCIDR("172.16.0.0/12")}, 56 - {mustParseCIDR("192.168.0.0/16")}, 57 - {mustParseCIDR("127.0.0.0/8")}, 58 - {mustParseCIDR("::1/128")}, 59 - {mustParseCIDR("fc00::/7")}, 60 - } 95 + return isPrivateIP(ip) 96 + } 61 97 98 + func isPrivateIP(ip net.IP) bool { 62 99 for _, r := range privateRanges { 63 100 if r.network.Contains(ip) { 64 101 return true 65 102 } 66 103 } 67 - 68 104 return false 69 105 } 70 106
+45 -4
validation_test.go
··· 1 1 package main 2 2 3 3 import ( 4 + "net" 4 5 "testing" 5 6 ) 6 7 ··· 62 63 } 63 64 }) 64 65 66 + t.Run("169.254.x.x rejected", func(t *testing.T) { 67 + if err := ValidateTokenEndpoint("https://169.254.169.254/oauth/token"); err == nil { 68 + t.Error("expected error for 169.254.x.x (link-local/cloud metadata)") 69 + } 70 + }) 71 + 65 72 t.Run("IPv6 loopback rejected", func(t *testing.T) { 66 73 if err := ValidateTokenEndpoint("https://[::1]/oauth/token"); err == nil { 67 74 t.Error("expected error for ::1") 75 + } 76 + }) 77 + 78 + t.Run("IPv6 link-local rejected", func(t *testing.T) { 79 + if err := ValidateTokenEndpoint("https://[fe80::1]/oauth/token"); err == nil { 80 + t.Error("expected error for fe80::1") 68 81 } 69 82 }) 70 83 71 84 t.Run("cached host succeeds", func(t *testing.T) { 85 + // Pre-populate the cache to simulate a previously validated host 86 + validatedHosts.Store("cached-test-host.example.com", true) 72 87 host := "https://cached-test-host.example.com/oauth/token" 73 - if err := ValidateTokenEndpoint(host); err != nil { 74 - t.Fatalf("first call failed: %v", err) 75 - } 76 - // Second call should hit cache and succeed 77 88 if err := ValidateTokenEndpoint(host); err != nil { 78 89 t.Errorf("cached call failed: %v", err) 79 90 } ··· 100 111 {"172.31.255.255", true}, 101 112 {"192.168.0.1", true}, 102 113 {"192.168.255.255", true}, 114 + {"169.254.169.254", true}, 115 + {"169.254.0.1", true}, 103 116 {"::1", true}, 104 117 {"fc00::1", true}, 118 + {"fe80::1", true}, 105 119 {"8.8.8.8", false}, 106 120 {"1.1.1.1", false}, 107 121 {"bsky.social", false}, ··· 118 132 }) 119 133 } 120 134 } 135 + 136 + func TestIsPrivateIP(t *testing.T) { 137 + tests := []struct { 138 + ipStr string 139 + private bool 140 + }{ 141 + {"169.254.169.254", true}, 142 + {"169.254.0.1", true}, 143 + {"fe80::1", true}, 144 + {"fe80::abcd", true}, 145 + {"2001:db8::1", false}, 146 + {"8.8.4.4", false}, 147 + } 148 + 149 + for _, tt := range tests { 150 + t.Run(tt.ipStr, func(t *testing.T) { 151 + ip := net.ParseIP(tt.ipStr) 152 + if ip == nil { 153 + t.Fatalf("failed to parse IP: %s", tt.ipStr) 154 + } 155 + result := isPrivateIP(ip) 156 + if result != tt.private { 157 + t.Errorf("isPrivateIP(%q) = %v, want %v", tt.ipStr, result, tt.private) 158 + } 159 + }) 160 + } 161 + }