Stateless auth proxy that converts AT Protocol native apps from public to confidential OAuth clients. Deploy once, get 180-day refresh tokens instead of 24-hour ones.
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

Add key rotation support via AUTH_OLD_PRIVATE_KEY and per-IP/global rate limiting on token and PAR endpoints

+411 -42
+15 -6
README.md
··· 44 44 45 45 | Variable | Required | Default | Description | 46 46 |----------|----------|---------|-------------| 47 - | `AUTH_PRIVATE_KEY` | Yes | — | PEM-encoded EC P-256 private key | 47 + | `AUTH_PRIVATE_KEY` | Yes | — | PEM-encoded EC P-256 private key (active signing key) | 48 48 | `AUTH_CLIENT_ID` | Yes | — | Your app's OAuth client_id (client-metadata.json URL) | 49 - | `AUTH_KEY_ID` | No | `atproto-auth-1` | JWKS key identifier (`kid`) | 49 + | `AUTH_KEY_ID` | No | `atproto-auth-1` | JWKS key identifier (`kid`) for the active key | 50 + | `AUTH_OLD_PRIVATE_KEY` | No | — | PEM-encoded EC P-256 private key (previous key, for rotation) | 51 + | `AUTH_OLD_KEY_ID` | No | — | JWKS key identifier (`kid`) for the old key (required with `AUTH_OLD_PRIVATE_KEY`) | 50 52 | `AUTH_BIND` | No | `:8080` | Listen address | 51 53 | `AUTH_ALLOWED_ORIGINS` | No | `*` | CORS allowed origins | 54 + | `AUTH_RATE_LIMIT_PER_IP` | No | `10` | Max requests per IP per minute on `/oauth/token` and `/oauth/par` (0 to disable) | 55 + | `AUTH_RATE_LIMIT_GLOBAL` | No | `100` | Max total requests per minute on `/oauth/token` and `/oauth/par` (0 to disable) | 52 56 53 57 ## How It Works 54 58 ··· 116 120 117 121 ## Key Rotation 118 122 123 + The proxy supports zero-downtime key rotation. During rotation, both old and new public keys are published in the JWKS so existing sessions bound to the old key continue to work. 124 + 119 125 1. Generate a new key pair with a new `kid` (e.g., `atproto-auth-2`) 120 - 2. Temporarily serve both old and new public keys in the JWKS 121 - 3. Deploy — the auth server will fetch the updated JWKS 122 - 4. After 24+ hours, remove the old key 123 - 5. Update `AUTH_PRIVATE_KEY` and `AUTH_KEY_ID` to the new key only 126 + 2. Set `AUTH_OLD_PRIVATE_KEY` and `AUTH_OLD_KEY_ID` to your current key values 127 + 3. Set `AUTH_PRIVATE_KEY` and `AUTH_KEY_ID` to the new key 128 + 4. Deploy — the JWKS now serves both keys; new assertions use the new key 129 + 5. After 24+ hours, remove `AUTH_OLD_PRIVATE_KEY` and `AUTH_OLD_KEY_ID` 130 + 131 + The active key (`AUTH_PRIVATE_KEY`) is always used for signing new assertions. The old key is only published in the JWKS so auth servers can still verify existing sessions. 124 132 125 133 ## Security Considerations 126 134 ··· 130 138 - **No token logging**: Token values, auth codes, and refresh tokens are never logged 131 139 - **HTTPS required**: The proxy must be served over HTTPS in production (handled automatically by Railway/Fly.io) 132 140 - **DPoP passthrough**: The proxy never sees DPoP private keys — proofs are between the device and auth server 141 + - **Rate limiting**: Per-IP and global rate limits on `/oauth/token` and `/oauth/par` (configurable, defaults to 10/min per IP, 100/min global) 133 142 - **Stateless**: No database, no user data stored — the only secret is the client signing key in an environment variable 134 143 135 144 ## License
+40 -10
config.go
··· 3 3 import ( 4 4 "fmt" 5 5 "os" 6 + "strconv" 6 7 ) 7 8 8 9 type Config struct { 9 - PrivateKeyPEM string 10 - ClientID string 11 - KeyID string 12 - Bind string 13 - AllowedOrigins string 10 + PrivateKeyPEM string 11 + ClientID string 12 + KeyID string 13 + OldPrivateKeyPEM string 14 + OldKeyID string 15 + Bind string 16 + AllowedOrigins string 17 + RateLimitPerIP int 18 + RateLimitGlobal int 14 19 } 15 20 16 21 func LoadConfig() (*Config, error) { ··· 29 34 keyID = "atproto-auth-1" 30 35 } 31 36 37 + oldPrivateKey := os.Getenv("AUTH_OLD_PRIVATE_KEY") 38 + oldKeyID := os.Getenv("AUTH_OLD_KEY_ID") 39 + 32 40 bind := os.Getenv("AUTH_BIND") 33 41 if bind == "" { 34 42 bind = ":8080" ··· 39 47 allowedOrigins = "*" 40 48 } 41 49 50 + rateLimitPerIP := 10 51 + if v := os.Getenv("AUTH_RATE_LIMIT_PER_IP"); v != "" { 52 + n, err := strconv.Atoi(v) 53 + if err != nil || n < 0 { 54 + return nil, fmt.Errorf("AUTH_RATE_LIMIT_PER_IP must be a non-negative integer") 55 + } 56 + rateLimitPerIP = n 57 + } 58 + 59 + rateLimitGlobal := 100 60 + if v := os.Getenv("AUTH_RATE_LIMIT_GLOBAL"); v != "" { 61 + n, err := strconv.Atoi(v) 62 + if err != nil || n < 0 { 63 + return nil, fmt.Errorf("AUTH_RATE_LIMIT_GLOBAL must be a non-negative integer") 64 + } 65 + rateLimitGlobal = n 66 + } 67 + 42 68 return &Config{ 43 - PrivateKeyPEM: privateKey, 44 - ClientID: clientID, 45 - KeyID: keyID, 46 - Bind: bind, 47 - AllowedOrigins: allowedOrigins, 69 + PrivateKeyPEM: privateKey, 70 + ClientID: clientID, 71 + KeyID: keyID, 72 + OldPrivateKeyPEM: oldPrivateKey, 73 + OldKeyID: oldKeyID, 74 + Bind: bind, 75 + AllowedOrigins: allowedOrigins, 76 + RateLimitPerIP: rateLimitPerIP, 77 + RateLimitGlobal: rateLimitGlobal, 48 78 }, nil 49 79 }
+26 -18
keys.go
··· 34 34 return ecKey, nil 35 35 } 36 36 37 - func BuildJWKS(privateKey *ecdsa.PrivateKey, kid string) ([]byte, error) { 38 - pubKey := privateKey.Public() 37 + type keyEntry struct { 38 + privateKey *ecdsa.PrivateKey 39 + kid string 40 + } 39 41 40 - jwkKey, err := jwk.FromRaw(pubKey) 41 - if err != nil { 42 - return nil, fmt.Errorf("failed to create JWK from public key: %w", err) 43 - } 42 + func BuildJWKS(keys []keyEntry) ([]byte, error) { 43 + set := jwk.NewSet() 44 44 45 - if err := jwkKey.Set(jwk.KeyIDKey, kid); err != nil { 46 - return nil, fmt.Errorf("failed to set kid: %w", err) 47 - } 48 - if err := jwkKey.Set(jwk.AlgorithmKey, jwa.ES256); err != nil { 49 - return nil, fmt.Errorf("failed to set alg: %w", err) 50 - } 51 - if err := jwkKey.Set(jwk.KeyUsageKey, "sig"); err != nil { 52 - return nil, fmt.Errorf("failed to set use: %w", err) 53 - } 45 + for _, k := range keys { 46 + pubKey := k.privateKey.Public() 47 + 48 + jwkKey, err := jwk.FromRaw(pubKey) 49 + if err != nil { 50 + return nil, fmt.Errorf("failed to create JWK from public key (kid=%s): %w", k.kid, err) 51 + } 54 52 55 - set := jwk.NewSet() 56 - if err := set.AddKey(jwkKey); err != nil { 57 - return nil, fmt.Errorf("failed to add key to set: %w", err) 53 + if err := jwkKey.Set(jwk.KeyIDKey, k.kid); err != nil { 54 + return nil, fmt.Errorf("failed to set kid: %w", err) 55 + } 56 + if err := jwkKey.Set(jwk.AlgorithmKey, jwa.ES256); err != nil { 57 + return nil, fmt.Errorf("failed to set alg: %w", err) 58 + } 59 + if err := jwkKey.Set(jwk.KeyUsageKey, "sig"); err != nil { 60 + return nil, fmt.Errorf("failed to set use: %w", err) 61 + } 62 + 63 + if err := set.AddKey(jwkKey); err != nil { 64 + return nil, fmt.Errorf("failed to add key to set: %w", err) 65 + } 58 66 } 59 67 60 68 jsonBytes, err := json.Marshal(set)
+49 -2
keys_test.go
··· 78 78 } 79 79 80 80 kid := "test-key-1" 81 - jwksBytes, err := BuildJWKS(key, kid) 81 + jwksBytes, err := BuildJWKS([]keyEntry{{privateKey: key, kid: kid}}) 82 82 if err != nil { 83 83 t.Fatalf("unexpected error: %v", err) 84 84 } ··· 123 123 } 124 124 } 125 125 126 + func TestBuildJWKS_MultipleKeys(t *testing.T) { 127 + pem1 := generateTestPEM(t) 128 + key1, err := ParsePrivateKey(pem1) 129 + if err != nil { 130 + t.Fatalf("failed to parse key1: %v", err) 131 + } 132 + 133 + pem2 := generateTestPEM(t) 134 + key2, err := ParsePrivateKey(pem2) 135 + if err != nil { 136 + t.Fatalf("failed to parse key2: %v", err) 137 + } 138 + 139 + jwksBytes, err := BuildJWKS([]keyEntry{ 140 + {privateKey: key1, kid: "new-key"}, 141 + {privateKey: key2, kid: "old-key"}, 142 + }) 143 + if err != nil { 144 + t.Fatalf("unexpected error: %v", err) 145 + } 146 + 147 + var jwks struct { 148 + Keys []struct { 149 + Kid string `json:"kid"` 150 + Kty string `json:"kty"` 151 + } `json:"keys"` 152 + } 153 + if err := json.Unmarshal(jwksBytes, &jwks); err != nil { 154 + t.Fatalf("failed to unmarshal JWKS: %v", err) 155 + } 156 + 157 + if len(jwks.Keys) != 2 { 158 + t.Fatalf("expected 2 keys, got %d", len(jwks.Keys)) 159 + } 160 + 161 + kids := map[string]bool{} 162 + for _, k := range jwks.Keys { 163 + kids[k.Kid] = true 164 + } 165 + if !kids["new-key"] { 166 + t.Error("expected kid=new-key in JWKS") 167 + } 168 + if !kids["old-key"] { 169 + t.Error("expected kid=old-key in JWKS") 170 + } 171 + } 172 + 126 173 func TestBuildJWKS_NoPrivateKey(t *testing.T) { 127 174 pemData := generateTestPEM(t) 128 175 key, err := ParsePrivateKey(pemData) ··· 130 177 t.Fatalf("failed to parse key: %v", err) 131 178 } 132 179 133 - jwksBytes, err := BuildJWKS(key, "test-key") 180 + jwksBytes, err := BuildJWKS([]keyEntry{{privateKey: key, kid: "test-key"}}) 134 181 if err != nil { 135 182 t.Fatalf("unexpected error: %v", err) 136 183 }
+23 -3
main.go
··· 21 21 log.Fatalf("failed to parse private key: %v", err) 22 22 } 23 23 24 - jwksJSON, err := BuildJWKS(privateKey, cfg.KeyID) 24 + keys := []keyEntry{{privateKey: privateKey, kid: cfg.KeyID}} 25 + 26 + if cfg.OldPrivateKeyPEM != "" { 27 + oldKey, err := ParsePrivateKey(cfg.OldPrivateKeyPEM) 28 + if err != nil { 29 + log.Fatalf("failed to parse old private key: %v", err) 30 + } 31 + oldKID := cfg.OldKeyID 32 + if oldKID == "" { 33 + log.Fatal("AUTH_OLD_KEY_ID is required when AUTH_OLD_PRIVATE_KEY is set") 34 + } 35 + if oldKID == cfg.KeyID { 36 + log.Fatal("AUTH_OLD_KEY_ID must differ from AUTH_KEY_ID") 37 + } 38 + keys = append(keys, keyEntry{privateKey: oldKey, kid: oldKID}) 39 + log.Printf("key rotation: serving both %s (active) and %s (old) in JWKS", cfg.KeyID, oldKID) 40 + } 41 + 42 + jwksJSON, err := BuildJWKS(keys) 25 43 if err != nil { 26 44 log.Fatalf("failed to build JWKS: %v", err) 27 45 } ··· 31 49 log.Fatalf("failed to create signer: %v", err) 32 50 } 33 51 52 + rl := newRateLimiter(cfg.RateLimitPerIP, cfg.RateLimitGlobal) 53 + 34 54 mux := http.NewServeMux() 35 55 36 56 mux.HandleFunc("GET /.well-known/jwks.json", HandleJWKS(jwksJSON)) 37 - mux.HandleFunc("POST /oauth/token", HandleToken(signingKey, cfg.ClientID)) 38 - mux.HandleFunc("POST /oauth/par", HandlePAR(signingKey, cfg.ClientID)) 57 + mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, HandleToken(signingKey, cfg.ClientID))) 58 + mux.HandleFunc("POST /oauth/par", RateLimitMiddleware(rl, HandlePAR(signingKey, cfg.ClientID))) 39 59 mux.HandleFunc("GET /health", HandleHealth) 40 60 41 61 handler := CORSMiddleware(cfg.AllowedOrigins, mux)
+9 -3
main_test.go
··· 12 12 13 13 func setupTestServer(t *testing.T) (*httptest.Server, func()) { 14 14 t.Helper() 15 + return setupTestServerWithRateLimit(t, 0, 0) 16 + } 17 + 18 + func setupTestServerWithRateLimit(t *testing.T, perIP, global int) (*httptest.Server, func()) { 19 + t.Helper() 15 20 16 21 pemData := generateTestPEM(t) 17 22 key, err := ParsePrivateKey(pemData) ··· 19 24 t.Fatalf("failed to parse key: %v", err) 20 25 } 21 26 22 - jwksJSON, err := BuildJWKS(key, "test-kid") 27 + jwksJSON, err := BuildJWKS([]keyEntry{{privateKey: key, kid: "test-kid"}}) 23 28 if err != nil { 24 29 t.Fatalf("failed to build JWKS: %v", err) 25 30 } ··· 30 35 } 31 36 32 37 clientID := "https://example.com/oauth/client-metadata.json" 38 + rl := newRateLimiter(perIP, global) 33 39 34 40 mux := http.NewServeMux() 35 41 mux.HandleFunc("GET /.well-known/jwks.json", HandleJWKS(jwksJSON)) 36 - mux.HandleFunc("POST /oauth/token", HandleToken(signingKey, clientID)) 37 - mux.HandleFunc("POST /oauth/par", HandlePAR(signingKey, clientID)) 42 + mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, HandleToken(signingKey, clientID))) 43 + mux.HandleFunc("POST /oauth/par", RateLimitMiddleware(rl, HandlePAR(signingKey, clientID))) 38 44 mux.HandleFunc("GET /health", HandleHealth) 39 45 40 46 handler := CORSMiddleware("*", mux)
+101
ratelimit.go
··· 1 + package main 2 + 3 + import ( 4 + "net" 5 + "net/http" 6 + "sync" 7 + "time" 8 + ) 9 + 10 + type rateLimiter struct { 11 + perIP int 12 + global int 13 + window time.Duration 14 + mu sync.Mutex 15 + buckets map[string]*bucket 16 + globalCt int 17 + windowAt time.Time 18 + } 19 + 20 + type bucket struct { 21 + count int 22 + windowAt time.Time 23 + } 24 + 25 + func newRateLimiter(perIP, global int) *rateLimiter { 26 + return &rateLimiter{ 27 + perIP: perIP, 28 + global: global, 29 + window: time.Minute, 30 + buckets: make(map[string]*bucket), 31 + windowAt: time.Now().Truncate(time.Minute), 32 + } 33 + } 34 + 35 + func (rl *rateLimiter) allow(ip string) bool { 36 + if rl.perIP == 0 && rl.global == 0 { 37 + return true 38 + } 39 + 40 + rl.mu.Lock() 41 + defer rl.mu.Unlock() 42 + 43 + now := time.Now() 44 + currentWindow := now.Truncate(rl.window) 45 + 46 + if currentWindow.After(rl.windowAt) { 47 + rl.windowAt = currentWindow 48 + rl.globalCt = 0 49 + rl.buckets = make(map[string]*bucket) 50 + } 51 + 52 + if rl.global > 0 { 53 + if rl.globalCt >= rl.global { 54 + return false 55 + } 56 + } 57 + 58 + if rl.perIP > 0 { 59 + b, ok := rl.buckets[ip] 60 + if !ok { 61 + b = &bucket{windowAt: currentWindow} 62 + rl.buckets[ip] = b 63 + } 64 + if b.count >= rl.perIP { 65 + return false 66 + } 67 + b.count++ 68 + } 69 + 70 + rl.globalCt++ 71 + return true 72 + } 73 + 74 + func RateLimitMiddleware(rl *rateLimiter, next http.HandlerFunc) http.HandlerFunc { 75 + return func(w http.ResponseWriter, r *http.Request) { 76 + ip := clientIP(r) 77 + if !rl.allow(ip) { 78 + w.Header().Set("Retry-After", "60") 79 + http.Error(w, `{"error":"too_many_requests","error_description":"rate limit exceeded"}`, http.StatusTooManyRequests) 80 + return 81 + } 82 + next(w, r) 83 + } 84 + } 85 + 86 + func clientIP(r *http.Request) string { 87 + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { 88 + // First entry is the original client 89 + for i := 0; i < len(xff); i++ { 90 + if xff[i] == ',' { 91 + return xff[:i] 92 + } 93 + } 94 + return xff 95 + } 96 + host, _, err := net.SplitHostPort(r.RemoteAddr) 97 + if err != nil { 98 + return r.RemoteAddr 99 + } 100 + return host 101 + }
+148
ratelimit_test.go
··· 1 + package main 2 + 3 + import ( 4 + "net/http" 5 + "strings" 6 + "testing" 7 + ) 8 + 9 + func TestRateLimiter_PerIP(t *testing.T) { 10 + rl := newRateLimiter(3, 0) 11 + 12 + for i := 0; i < 3; i++ { 13 + if !rl.allow("1.2.3.4") { 14 + t.Fatalf("request %d should be allowed", i+1) 15 + } 16 + } 17 + 18 + if rl.allow("1.2.3.4") { 19 + t.Error("4th request from same IP should be blocked") 20 + } 21 + 22 + // Different IP should still be allowed 23 + if !rl.allow("5.6.7.8") { 24 + t.Error("request from different IP should be allowed") 25 + } 26 + } 27 + 28 + func TestRateLimiter_Global(t *testing.T) { 29 + rl := newRateLimiter(0, 5) 30 + 31 + for i := 0; i < 5; i++ { 32 + if !rl.allow("1.2.3.4") { 33 + t.Fatalf("request %d should be allowed", i+1) 34 + } 35 + } 36 + 37 + // Global limit hit, even from a new IP 38 + if rl.allow("9.9.9.9") { 39 + t.Error("request should be blocked by global limit") 40 + } 41 + } 42 + 43 + func TestRateLimiter_Combined(t *testing.T) { 44 + rl := newRateLimiter(2, 5) 45 + 46 + // First IP gets 2 47 + if !rl.allow("1.1.1.1") { 48 + t.Fatal("should be allowed") 49 + } 50 + if !rl.allow("1.1.1.1") { 51 + t.Fatal("should be allowed") 52 + } 53 + if rl.allow("1.1.1.1") { 54 + t.Error("per-IP limit should block") 55 + } 56 + 57 + // Second IP gets 2 more (global at 4 now, 2 counted + 2 more) 58 + // Wait, the per-IP block above didn't increment global. Let me reconsider. 59 + // Global count: 2 (from 1.1.1.1's allowed requests) 60 + if !rl.allow("2.2.2.2") { 61 + t.Fatal("should be allowed") 62 + } 63 + if !rl.allow("2.2.2.2") { 64 + t.Fatal("should be allowed") 65 + } 66 + // Global count: 4. Third IP gets 1 more before global limit. 67 + if !rl.allow("3.3.3.3") { 68 + t.Fatal("should be allowed") 69 + } 70 + // Global count: 5. Now global limit hit. 71 + if rl.allow("4.4.4.4") { 72 + t.Error("global limit should block") 73 + } 74 + } 75 + 76 + func TestRateLimiter_Disabled(t *testing.T) { 77 + rl := newRateLimiter(0, 0) 78 + 79 + for i := 0; i < 1000; i++ { 80 + if !rl.allow("1.2.3.4") { 81 + t.Fatalf("request %d should be allowed when rate limiting is disabled", i+1) 82 + } 83 + } 84 + } 85 + 86 + func TestRateLimitMiddleware_Returns429(t *testing.T) { 87 + srv, cleanup := setupTestServerWithRateLimit(t, 2, 0) 88 + defer cleanup() 89 + 90 + body := `{"token_endpoint":"https://bsky.social/oauth/token","issuer":"https://bsky.social","grant_type":"authorization_code","code":"test"}` 91 + 92 + // First 2 should succeed (400 from validation, but not 429) 93 + for i := 0; i < 2; i++ { 94 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 95 + if err != nil { 96 + t.Fatalf("request %d failed: %v", i+1, err) 97 + } 98 + resp.Body.Close() 99 + if resp.StatusCode == http.StatusTooManyRequests { 100 + t.Fatalf("request %d should not be rate limited", i+1) 101 + } 102 + } 103 + 104 + // 3rd should be rate limited 105 + resp, err := http.Post(srv.URL+"/oauth/token", "application/json", strings.NewReader(body)) 106 + if err != nil { 107 + t.Fatalf("request failed: %v", err) 108 + } 109 + defer resp.Body.Close() 110 + 111 + if resp.StatusCode != http.StatusTooManyRequests { 112 + t.Errorf("expected 429, got %d", resp.StatusCode) 113 + } 114 + 115 + if resp.Header.Get("Retry-After") != "60" { 116 + t.Errorf("expected Retry-After: 60, got %s", resp.Header.Get("Retry-After")) 117 + } 118 + } 119 + 120 + func TestClientIP(t *testing.T) { 121 + tests := []struct { 122 + name string 123 + remoteAddr string 124 + xff string 125 + expected string 126 + }{ 127 + {"remote addr with port", "1.2.3.4:12345", "", "1.2.3.4"}, 128 + {"remote addr without port", "1.2.3.4", "", "1.2.3.4"}, 129 + {"x-forwarded-for single", "9.9.9.9:1234", "1.2.3.4", "1.2.3.4"}, 130 + {"x-forwarded-for multiple", "9.9.9.9:1234", "1.2.3.4, 5.6.7.8", "1.2.3.4"}, 131 + } 132 + 133 + for _, tt := range tests { 134 + t.Run(tt.name, func(t *testing.T) { 135 + r := &http.Request{ 136 + RemoteAddr: tt.remoteAddr, 137 + Header: http.Header{}, 138 + } 139 + if tt.xff != "" { 140 + r.Header.Set("X-Forwarded-For", tt.xff) 141 + } 142 + got := clientIP(r) 143 + if got != tt.expected { 144 + t.Errorf("clientIP() = %q, want %q", got, tt.expected) 145 + } 146 + }) 147 + } 148 + }