Stateless auth proxy that converts AT Protocol native apps from public to confidential OAuth clients. Deploy once, get 180-day refresh tokens instead of 24-hour ones.
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

better refresh handling

+181 -3
+140
cache.go
··· 1 + package main 2 + 3 + import ( 4 + "context" 5 + "crypto/sha256" 6 + "encoding/hex" 7 + "sync" 8 + "time" 9 + ) 10 + 11 + const ( 12 + refreshCacheTTL = 10 * time.Minute 13 + refreshCacheMaxSize = 10_000 14 + ) 15 + 16 + // cachedTokenResponse is the portion of a successful /oauth/token proxy 17 + // response we replay to clients whose original response was lost in transit 18 + // (TCP reset, app backgrounded before reading the body, network blip between 19 + // proxy and client, etc.). Because AT Protocol refresh tokens rotate on use, 20 + // the client only has one shot — if the original response never lands, the 21 + // token is consumed on the authorization server but the client is stuck with 22 + // the old value and every retry gets "invalid_grant: Refresh token replayed". 23 + type cachedTokenResponse struct { 24 + response *upstreamResponse 25 + usedKeyID string 26 + } 27 + 28 + type refreshCacheEntry struct { 29 + ready chan struct{} 30 + result *cachedTokenResponse 31 + expires time.Time 32 + } 33 + 34 + // refreshCache combines a short-TTL idempotency cache with a single-flight 35 + // gate. Concurrent requests for the same refresh token coalesce onto one 36 + // upstream call; retries after a lost response pick up the cached result. 37 + // Negative outcomes (4xx/5xx from upstream, or upstream transport failures) 38 + // are not cached — the authorization server is authoritative on whether a 39 + // given token is still alive. 40 + type refreshCache struct { 41 + mu sync.Mutex 42 + entries map[string]*refreshCacheEntry 43 + } 44 + 45 + func newRefreshCache() *refreshCache { 46 + return &refreshCache{entries: make(map[string]*refreshCacheEntry)} 47 + } 48 + 49 + func refreshCacheKey(refreshToken string) string { 50 + sum := sha256.Sum256([]byte(refreshToken)) 51 + return hex.EncodeToString(sum[:]) 52 + } 53 + 54 + // acquire returns the cache entry for the given refresh token and reports 55 + // whether the caller is the leader (must perform the upstream call and then 56 + // call finalize or release) or a follower (should wait on the entry). 57 + func (c *refreshCache) acquire(refreshToken string) (entry *refreshCacheEntry, isLeader bool) { 58 + key := refreshCacheKey(refreshToken) 59 + c.mu.Lock() 60 + defer c.mu.Unlock() 61 + c.sweepLocked() 62 + 63 + if existing, ok := c.entries[key]; ok { 64 + return existing, false 65 + } 66 + 67 + entry = &refreshCacheEntry{ 68 + ready: make(chan struct{}), 69 + expires: time.Now().Add(refreshCacheTTL), 70 + } 71 + c.entries[key] = entry 72 + return entry, true 73 + } 74 + 75 + // finalize records a successful upstream response and wakes any followers. 76 + // The entry remains in the cache for its TTL so later retries whose original 77 + // response was lost can pick up the same rotated token. 78 + func (c *refreshCache) finalize(entry *refreshCacheEntry, result *cachedTokenResponse) { 79 + c.mu.Lock() 80 + entry.result = result 81 + c.mu.Unlock() 82 + close(entry.ready) 83 + } 84 + 85 + // release removes a non-cacheable entry and wakes any followers with no 86 + // cached result so they can fall through to a fresh upstream attempt. 87 + func (c *refreshCache) release(refreshToken string, entry *refreshCacheEntry) { 88 + key := refreshCacheKey(refreshToken) 89 + c.mu.Lock() 90 + if current, ok := c.entries[key]; ok && current == entry { 91 + delete(c.entries, key) 92 + } 93 + c.mu.Unlock() 94 + close(entry.ready) 95 + } 96 + 97 + // wait blocks until the entry is finalized or released, honoring context 98 + // cancellation. Returns nil if the entry was released or the context ended. 99 + func (c *refreshCache) wait(ctx context.Context, entry *refreshCacheEntry) *cachedTokenResponse { 100 + select { 101 + case <-entry.ready: 102 + case <-ctx.Done(): 103 + return nil 104 + } 105 + c.mu.Lock() 106 + defer c.mu.Unlock() 107 + return entry.result 108 + } 109 + 110 + // sweepLocked drops expired entries that have already settled. In-flight 111 + // entries are left alone regardless of TTL — the leader is responsible for 112 + // calling finalize or release, after which the next sweep will reap them. 113 + // Called under c.mu. 114 + func (c *refreshCache) sweepLocked() { 115 + now := time.Now() 116 + for key, entry := range c.entries { 117 + select { 118 + case <-entry.ready: 119 + if now.After(entry.expires) { 120 + delete(c.entries, key) 121 + } 122 + default: 123 + } 124 + } 125 + if len(c.entries) < refreshCacheMaxSize { 126 + return 127 + } 128 + // Bound memory under pathological load: drop settled entries until we 129 + // are back under the cap. In-flight entries are never evicted. 130 + for key, entry := range c.entries { 131 + if len(c.entries) < refreshCacheMaxSize { 132 + break 133 + } 134 + select { 135 + case <-entry.ready: 136 + delete(c.entries, key) 137 + default: 138 + } 139 + } 140 + }
+38 -1
handler_token.go
··· 18 18 RefreshToken string `json:"refresh_token,omitempty"` 19 19 } 20 20 21 - func HandleToken(signers *SignerSet, clientID string) http.HandlerFunc { 21 + func HandleToken(signers *SignerSet, clientID string, cache *refreshCache) http.HandlerFunc { 22 22 return func(w http.ResponseWriter, r *http.Request) { 23 23 var req tokenRequest 24 24 if err := json.NewDecoder(r.Body).Decode(&req); err != nil { ··· 42 42 if err := ValidateTokenEndpointForIssuer(r.Context(), req.Issuer, req.TokenEndpoint); err != nil { 43 43 writeAPIError(w, err) 44 44 return 45 + } 46 + 47 + // Idempotency: refresh tokens rotate single-use. If a client's original 48 + // response was lost in transit, the token is spent upstream but the 49 + // client still holds the old value. Coalesce concurrent duplicates and 50 + // briefly cache successful upstream responses so retries recover the 51 + // rotated token instead of getting "Refresh token replayed". 52 + var cacheSlot *refreshCacheEntry 53 + if req.GrantType == "refresh_token" && req.RefreshToken != "" { 54 + entry, isLeader := cache.acquire(req.RefreshToken) 55 + if !isLeader { 56 + if cached := cache.wait(r.Context(), entry); cached != nil { 57 + w.Header().Set(authProxyKeyIDHeader, cached.usedKeyID) 58 + if err := WriteProxiedResponse(w, cached.response); err != nil { 59 + log.Printf("failed to write proxied response: %v", err) 60 + } 61 + return 62 + } 63 + // Leader released (non-cacheable outcome) or context ended. 64 + // Fall through and make our own upstream attempt without 65 + // holding a cache slot; avoids looping on contention. 66 + } else { 67 + cacheSlot = entry 68 + defer func() { 69 + if cacheSlot != nil { 70 + cache.release(req.RefreshToken, cacheSlot) 71 + } 72 + }() 73 + } 45 74 } 46 75 47 76 candidateKeyIDs, err := signers.CandidateKeyIDs(req.KeyID) ··· 103 132 } 104 133 105 134 break 135 + } 136 + 137 + // Finalize the cache entry before writing to the client so concurrent 138 + // followers can unblock during the write. Cache only 2xx responses — 139 + // negative outcomes are authoritative and shouldn't mask a later retry. 140 + if cacheSlot != nil && proxied != nil && proxied.statusCode >= 200 && proxied.statusCode < 300 { 141 + cache.finalize(cacheSlot, &cachedTokenResponse{response: proxied, usedKeyID: usedKeyID}) 142 + cacheSlot = nil 106 143 } 107 144 108 145 w.Header().Set(authProxyKeyIDHeader, usedKeyID)
+2 -1
main.go
··· 50 50 } 51 51 52 52 rl := newRateLimiter(cfg.RateLimitPerIP, cfg.RateLimitGlobal) 53 + tokenCache := newRefreshCache() 53 54 54 55 mux := http.NewServeMux() 55 56 56 57 mux.HandleFunc("GET /.well-known/jwks.json", HandleJWKS(jwksJSON)) 57 - mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, cfg.TrustProxyHeaders, HandleToken(signers, cfg.ClientID))) 58 + mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, cfg.TrustProxyHeaders, HandleToken(signers, cfg.ClientID, tokenCache))) 58 59 mux.HandleFunc("POST /oauth/par", RateLimitMiddleware(rl, cfg.TrustProxyHeaders, HandlePAR(signers, cfg.ClientID))) 59 60 mux.HandleFunc("GET /health", HandleHealth) 60 61
+1 -1
main_test.go
··· 58 58 59 59 mux := http.NewServeMux() 60 60 mux.HandleFunc("GET /.well-known/jwks.json", HandleJWKS(jwksJSON)) 61 - mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, trustProxyHeaders, HandleToken(signers, testClientID))) 61 + mux.HandleFunc("POST /oauth/token", RateLimitMiddleware(rl, trustProxyHeaders, HandleToken(signers, testClientID, newRefreshCache()))) 62 62 mux.HandleFunc("POST /oauth/par", RateLimitMiddleware(rl, trustProxyHeaders, HandlePAR(signers, testClientID))) 63 63 mux.HandleFunc("GET /health", HandleHealth) 64 64