package middleware

import (
	"log/slog"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)

// TokenBucket implements a simple per-client token bucket rate limiter.
//
// Each client key owns a bucket that starts full (burst tokens) and refills at
// rate tokens per second, capped at burst. A request is admitted when its
// bucket holds at least one token, which is then consumed.
//
// A TokenBucket must be created with NewTokenBucket. Call Stop when the
// limiter is no longer needed to end its background cleanup goroutine.
type TokenBucket struct {
	mu      sync.Mutex         // guards buckets and every *bucket stored in it
	buckets map[string]*bucket // per-client state, keyed by client identifier
	rate    float64            // tokens per second
	burst   int                // max tokens

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once     // makes Stop safe to call more than once
}

// bucket is the refill state tracked for a single client key.
type bucket struct {
	tokens   float64   // tokens currently available
	lastTime time.Time // time of the last refill; also used to detect idle buckets
}

// NewTokenBucket creates a rate limiter allowing rate requests per second with
// burst capacity.
//
// It starts a background goroutine that periodically removes idle buckets;
// call Stop to terminate that goroutine.
func NewTokenBucket(rate float64, burst int) *TokenBucket {
	tb := &TokenBucket{
		buckets: make(map[string]*bucket),
		rate:    rate,
		burst:   burst,
		stop:    make(chan struct{}),
	}
	// Periodically clean up stale buckets.
	go tb.cleanup()
	return tb
}

// Stop terminates the background cleanup goroutine. It is safe to call Stop
// more than once. The limiter keeps answering requests after Stop, but idle
// buckets are no longer evicted.
func (tb *TokenBucket) Stop() {
	if tb.stop == nil {
		return
	}
	tb.stopOnce.Do(func() { close(tb.stop) })
}

// allow reports whether a request identified by key may proceed, consuming one
// token from the key's bucket when it does. A key seen for the first time gets
// a full bucket.
func (tb *TokenBucket) allow(key string) bool {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	// Take a single timestamp while holding the lock so that refills for a
	// given bucket are computed from non-decreasing times.
	now := time.Now()

	b, ok := tb.buckets[key]
	if !ok {
		b = &bucket{tokens: float64(tb.burst), lastTime: now}
		tb.buckets[key] = b
	}

	// Refill proportionally to the time elapsed, never exceeding burst.
	elapsed := now.Sub(b.lastTime).Seconds()
	b.tokens += elapsed * tb.rate
	if b.tokens > float64(tb.burst) {
		b.tokens = float64(tb.burst)
	}
	b.lastTime = now

	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}

// cleanup evicts idle buckets every five minutes until Stop is called.
func (tb *TokenBucket) cleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-tb.stop:
			return
		case <-ticker.C:
			tb.evictIdle(time.Now().Add(-10 * time.Minute))
		}
	}
}

// evictIdle deletes every bucket whose last refill happened before cutoff.
func (tb *TokenBucket) evictIdle(cutoff time.Time) {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	for key, b := range tb.buckets {
		if b.lastTime.Before(cutoff) {
			delete(tb.buckets, key)
		}
	}
}

// rateLimitKey derives the client identifier used to select a bucket.
//
// It prefers the first entry of the X-Forwarded-For header and otherwise uses
// the host part of r.RemoteAddr. The port is stripped because RemoteAddr is
// "host:port" and the port differs per connection, which would otherwise give
// every connection from the same client its own bucket.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// proxy in front of this server overwrites it; confirm the deployment does so,
// otherwise clients can choose their own key.
func rateLimitKey(r *http.Request) string {
	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		// The header may carry a comma-separated proxy chain; key on the
		// first entry only.
		first, _, _ := strings.Cut(fwd, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr is not in host:port form; fall back to the raw value.
		return r.RemoteAddr
	}
	return host
}

// Limit wraps an http.Handler with rate limiting.
//
// Requests that exceed the limit receive a 429 Too Many Requests response with
// a JSON error body and a Retry-After header; next is not invoked for them.
func (tb *TokenBucket) Limit(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip := rateLimitKey(r)

		if !tb.allow(ip) {
			slog.Warn("rate limit exceeded", "ip", ip, "path", r.URL.Path)
			w.Header().Set("Content-Type", "application/json")
			w.Header().Set("Retry-After", "10")
			w.WriteHeader(http.StatusTooManyRequests)
			if _, err := w.Write([]byte(`{"error":"rate limit exceeded, try again later"}`)); err != nil {
				// The client most likely went away; nothing more can be sent.
				slog.Debug("writing rate limit response", "ip", ip, "error", err)
			}
			return
		}

		next.ServeHTTP(w, r)
	})
}

// LimitFunc wraps an http.HandlerFunc with rate limiting.
func (tb *TokenBucket) LimitFunc(next http.HandlerFunc) http.HandlerFunc { limited := tb.Limit(http.HandlerFunc(next)) return limited.ServeHTTP }