package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestTokenBucket_AllowsBurst sends six back-to-back requests through a
// handler wrapped by NewTokenBucket(1.0, 5).LimitFunc and expects the first
// five to return 200 and the sixth to return 429.
func TestTokenBucket_AllowsBurst(t *testing.T) {
	limiter := NewTokenBucket(1.0, 5) // 1/sec, burst 5
	handler := limiter.LimitFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// hit pushes one GET /test through the wrapped handler and reports the
	// status code captured by the recorder.
	hit := func() int {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, httptest.NewRequest("GET", "/test", nil))
		return rec.Code
	}

	// The whole burst of five must be served.
	for n := 1; n <= 5; n++ {
		if code := hit(); code != http.StatusOK {
			t.Fatalf("request %d: expected 200, got %d", n, code)
		}
	}

	// One request past the burst must be rejected.
	if code := hit(); code != http.StatusTooManyRequests {
		t.Fatalf("request 6: expected 429, got %d", code)
	}
}

// TestTokenBucket_DifferentIPs sends requests carrying two different
// X-Forwarded-For values through a handler wrapped by
// NewTokenBucket(1.0, 2).LimitFunc. It expects the first address to be
// rejected on its third request while the second address is still served.
func TestTokenBucket_DifferentIPs(t *testing.T) {
	limiter := NewTokenBucket(1.0, 2) // 1/sec, burst 2
	handler := limiter.LimitFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// hitFrom pushes one GET /test with the given X-Forwarded-For value
	// through the wrapped handler and reports the recorded status code.
	hitFrom := func(ip string) int {
		req := httptest.NewRequest("GET", "/test", nil)
		req.Header.Set("X-Forwarded-For", ip)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		return rec.Code
	}

	const (
		firstIP  = "1.2.3.4"
		secondIP = "5.6.7.8"
	)

	// Use up the first address's burst of two.
	for n := 1; n <= 2; n++ {
		if code := hitFrom(firstIP); code != http.StatusOK {
			t.Fatalf("ip1 request %d: expected 200, got %d", n, code)
		}
	}

	// The first address is now expected to be limited.
	if code := hitFrom(firstIP); code != http.StatusTooManyRequests {
		t.Fatalf("ip1 request 3: expected 429, got %d", code)
	}

	// The second address has sent nothing yet and is expected to pass.
	if code := hitFrom(secondIP); code != http.StatusOK {
		t.Fatalf("ip2 request 1: expected 200, got %d", code)
	}
}