package main import ( "net" "net/http" "strings" "sync" "time" ) type rateLimiter struct { perIP int global int window time.Duration mu sync.Mutex buckets map[string]*bucket globalCt int windowAt time.Time } type bucket struct { count int windowAt time.Time } func newRateLimiter(perIP, global int) *rateLimiter { return &rateLimiter{ perIP: perIP, global: global, window: time.Minute, buckets: make(map[string]*bucket), windowAt: time.Now().Truncate(time.Minute), } } func (rl *rateLimiter) allow(ip string) bool { if rl.perIP == 0 && rl.global == 0 { return true } rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() currentWindow := now.Truncate(rl.window) if currentWindow.After(rl.windowAt) { rl.windowAt = currentWindow rl.globalCt = 0 rl.buckets = make(map[string]*bucket) } if rl.global > 0 { if rl.globalCt >= rl.global { return false } } if rl.perIP > 0 { b, ok := rl.buckets[ip] if !ok { b = &bucket{windowAt: currentWindow} rl.buckets[ip] = b } if b.count >= rl.perIP { return false } b.count++ } rl.globalCt++ return true } func RateLimitMiddleware(rl *rateLimiter, trustProxyHeaders bool, next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ip := clientIP(r, trustProxyHeaders) if !rl.allow(ip) { w.Header().Set("Retry-After", "60") writeJSONError(w, http.StatusTooManyRequests, "too_many_requests", "rate limit exceeded") return } next(w, r) } } func clientIP(r *http.Request, trustProxyHeaders bool) string { if trustProxyHeaders { if forwarded := forwardedIP(r.Header.Get("X-Forwarded-For")); forwarded != "" { return forwarded } if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); net.ParseIP(realIP) != nil { return realIP } } host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host } func forwardedIP(headerValue string) string { if headerValue == "" { return "" } first := headerValue if idx := strings.IndexRune(headerValue, ','); idx >= 0 { first = headerValue[:idx] } first = strings.TrimSpace(first) if net.ParseIP(first) == nil { return "" } return first }