Stateless auth proxy that converts AT Protocol native apps from public to confidential OAuth clients. Deploy once, get 180-day refresh tokens instead of 24-hour ones.
1package main
2
3import (
4 "net"
5 "net/http"
6 "strings"
7 "sync"
8 "time"
9)
10
// rateLimiter is a fixed-window request counter that enforces an optional
// per-IP limit and an optional global limit. A limit <= 0 disables that
// check. All counters are discarded whenever the wall clock enters a
// different window.
type rateLimiter struct {
	// Configuration; not modified after construction.
	perIP  int           // max admitted requests per IP per window
	global int           // max admitted requests across all IPs per window
	window time.Duration // length of one counting window

	// mu guards buckets, globalCt and windowAt.
	mu       sync.Mutex
	buckets  map[string]*bucket
	globalCt int       // requests admitted in the current window
	windowAt time.Time // start of the current window (wall clock)
}

// bucket holds one IP's admitted-request count for the current window.
type bucket struct {
	count int
	// windowAt records the window the bucket was created in. allow does not
	// read it, because the whole bucket map is replaced on each new window.
	windowAt time.Time
}

// newRateLimiter returns a limiter with one-minute windows, admitting at most
// perIP requests per client IP and global requests overall per window.
// Passing 0 for both disables limiting entirely.
func newRateLimiter(perIP, global int) *rateLimiter {
	const window = time.Minute
	return &rateLimiter{
		perIP:    perIP,
		global:   global,
		window:   window,
		buckets:  make(map[string]*bucket),
		windowAt: time.Now().Truncate(window),
	}
}

// allow reports whether a request from ip may proceed. If it may, the
// request is counted against both the global and the per-IP budget. A
// request rejected by the per-IP limit does not consume global budget.
func (rl *rateLimiter) allow(ip string) bool {
	// perIP and global are never written after construction, so reading
	// them without the lock is safe.
	if rl.perIP == 0 && rl.global == 0 {
		return true
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Truncate strips the monotonic clock reading, so windows follow the
	// wall clock. Reset on ANY window change, not just forward moves. If the
	// clock steps backwards, an After-only check would keep the stale
	// (future) window and its exhausted counters until the clock caught up.
	currentWindow := time.Now().Truncate(rl.window)
	if !currentWindow.Equal(rl.windowAt) {
		rl.windowAt = currentWindow
		rl.globalCt = 0
		// A fresh map (rather than clear) releases memory after a window
		// that saw many distinct IPs.
		rl.buckets = make(map[string]*bucket)
	}

	if rl.global > 0 && rl.globalCt >= rl.global {
		return false
	}

	if rl.perIP > 0 {
		b, ok := rl.buckets[ip]
		if !ok {
			b = &bucket{windowAt: currentWindow}
			rl.buckets[ip] = b
		}
		if b.count >= rl.perIP {
			return false
		}
		b.count++
	}

	rl.globalCt++
	return true
}
74
75func RateLimitMiddleware(rl *rateLimiter, trustProxyHeaders bool, next http.HandlerFunc) http.HandlerFunc {
76 return func(w http.ResponseWriter, r *http.Request) {
77 ip := clientIP(r, trustProxyHeaders)
78 if !rl.allow(ip) {
79 w.Header().Set("Retry-After", "60")
80 writeJSONError(w, http.StatusTooManyRequests, "too_many_requests", "rate limit exceeded")
81 return
82 }
83 next(w, r)
84 }
85}
86
87func clientIP(r *http.Request, trustProxyHeaders bool) string {
88 if trustProxyHeaders {
89 if forwarded := forwardedIP(r.Header.Get("X-Forwarded-For")); forwarded != "" {
90 return forwarded
91 }
92 if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); net.ParseIP(realIP) != nil {
93 return realIP
94 }
95 }
96
97 host, _, err := net.SplitHostPort(r.RemoteAddr)
98 if err != nil {
99 return r.RemoteAddr
100 }
101 return host
102}
103
// forwardedIP extracts the client address from an X-Forwarded-For header
// value. It returns "" when the header is empty or when the selected entry
// is not a valid IP literal. Entries of the form "ip:port" are not accepted.
//
// The rightmost entry is used because it is the one appended by the proxy
// directly in front of this server. Every entry to its left was supplied by
// earlier hops, including the client itself. A client could therefore put an
// arbitrary (e.g. per-request random) address there and evade per-IP rate
// limits. Behind more than one proxy hop, the rightmost entry is the
// address of the inner proxy rather than the client. Such deployments need
// the header normalized before it reaches this server.
func forwardedIP(headerValue string) string {
	if headerValue == "" {
		return ""
	}

	last := headerValue
	if idx := strings.LastIndexByte(headerValue, ','); idx >= 0 {
		last = headerValue[idx+1:]
	}
	last = strings.TrimSpace(last)

	// Do not walk further left when the last entry is malformed: those
	// entries are client-controlled.
	if net.ParseIP(last) == nil {
		return ""
	}
	return last
}