From 61dea2f2f41af1213340c45c543e0706933083b4 Mon Sep 17 00:00:00 2001 From: Wesley Willians Date: Thu, 5 Jun 2025 16:59:19 -0400 Subject: [PATCH] Ensure positive Retry-After header --- middleware/http.go | 11 ++++++----- middleware/http_test.go | 44 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/middleware/http.go b/middleware/http.go index a3c812d..e4709a1 100644 --- a/middleware/http.go +++ b/middleware/http.go @@ -3,6 +3,7 @@ package middleware import ( "encoding/json" "log/slog" + "math" "net" "net/http" "strconv" @@ -42,7 +43,7 @@ func (m *RateLimitMiddleware) Handler(next http.Handler) http.Handler { // Check rate limit resp, err := m.limiter.Allow(ip) if err != nil { - m.logger.Error("rate limit check failed", + m.logger.Error("rate limit check failed", "error", err, "ip", ip, ) @@ -51,9 +52,9 @@ func (m *RateLimitMiddleware) Handler(next http.Handler) http.Handler { } if !resp.Allowed { - // Calculate retry after in seconds - retryAfterSecs := int(time.Until(resp.RetryAfter).Seconds()) - + // Calculate retry after in seconds, rounding up and never negative + retryAfterSecs := int(math.Ceil(math.Max(0, time.Until(resp.RetryAfter).Seconds()))) + // Set headers w.Header().Set("Content-Type", "application/json") w.Header().Set("Retry-After", strconv.Itoa(retryAfterSecs)) @@ -94,7 +95,7 @@ func getClientIP(r *http.Request) string { return ips.String() } } - + // Extract from RemoteAddr ip, _, _ := net.SplitHostPort(r.RemoteAddr) return ip diff --git a/middleware/http_test.go b/middleware/http_test.go index cce4b09..6c0e7e3 100644 --- a/middleware/http_test.go +++ b/middleware/http_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strconv" "testing" "time" @@ -52,6 +53,37 @@ func TestRateLimitMiddleware(t *testing.T) { } } +func TestRetryAfterHeaderNonNegative(t *testing.T) { + logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + storage := &mockPastStorage{} + limiter := ratelimiter.New(storage) + middleware := NewRateLimitMiddleware(limiter, logger) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // exceed limit so middleware attempts to return Retry-After header + storage.count = 101 + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + middleware.Handler(handler).ServeHTTP(rec, req) + + hdr := rec.Header().Get("Retry-After") + if hdr == "" { + t.Fatalf("expected Retry-After header to be set") + } + + val, err := strconv.Atoi(hdr) + if err != nil { + t.Fatalf("invalid Retry-After header: %v", err) + } + + if val < 0 { + t.Errorf("expected Retry-After to be non-negative, got %d", val) + } +} + // Mock storage for testing type mockStorage struct { count int @@ -81,3 +113,15 @@ func (m *mockStorage) Reset(key string) error { m.count = 0 return nil } + +// mockPastStorage returns a past Retry-After time when blocked +type mockPastStorage struct { + mockStorage +} + +func (m *mockPastStorage) IsBlocked(key string) (bool, time.Time, error) { + if m.count > 100 { + return true, time.Now().Add(-time.Minute), nil + } + return false, time.Time{}, nil +}