Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions middleware/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middleware
import (
"encoding/json"
"log/slog"
"math"
"net"
"net/http"
"strconv"
Expand Down Expand Up @@ -42,7 +43,7 @@ func (m *RateLimitMiddleware) Handler(next http.Handler) http.Handler {
// Check rate limit
resp, err := m.limiter.Allow(ip)
if err != nil {
m.logger.Error("rate limit check failed",
m.logger.Error("rate limit check failed",
"error", err,
"ip", ip,
)
Expand All @@ -51,9 +52,9 @@ func (m *RateLimitMiddleware) Handler(next http.Handler) http.Handler {
}

if !resp.Allowed {
// Calculate retry after in seconds
retryAfterSecs := int(time.Until(resp.RetryAfter).Seconds())
// Calculate retry after in seconds, rounding up and never negative
retryAfterSecs := int(math.Ceil(math.Max(0, time.Until(resp.RetryAfter).Seconds())))

// Set headers
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After", strconv.Itoa(retryAfterSecs))
Expand Down Expand Up @@ -94,7 +95,7 @@ func getClientIP(r *http.Request) string {
return ips.String()
}
}

// Extract from RemoteAddr
ip, _, _ := net.SplitHostPort(r.RemoteAddr)
return ip
Expand Down
44 changes: 44 additions & 0 deletions middleware/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -52,6 +53,37 @@ func TestRateLimitMiddleware(t *testing.T) {
}
}

func TestRetryAfterHeaderNonNegative(t *testing.T) {
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
storage := &mockPastStorage{}
limiter := ratelimiter.New(storage)
middleware := NewRateLimitMiddleware(limiter, logger)

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

// exceed limit so middleware attempts to return Retry-After header
storage.count = 101
req := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
middleware.Handler(handler).ServeHTTP(rec, req)

hdr := rec.Header().Get("Retry-After")
if hdr == "" {
t.Fatalf("expected Retry-After header to be set")
}

val, err := strconv.Atoi(hdr)
if err != nil {
t.Fatalf("invalid Retry-After header: %v", err)
}

if val < 0 {
t.Errorf("expected Retry-After to be non-negative, got %d", val)
}
}

// Mock storage for testing
type mockStorage struct {
count int
Expand Down Expand Up @@ -81,3 +113,15 @@ func (m *mockStorage) Reset(key string) error {
m.count = 0
return nil
}

// mockPastStorage returns a past Retry-After time when blocked
type mockPastStorage struct {
mockStorage
}

func (m *mockPastStorage) IsBlocked(key string) (bool, time.Time, error) {
if m.count > 100 {
return true, time.Now().Add(-time.Minute), nil
}
return false, time.Time{}, nil
}