Skip to content

Commit 2b1473c

Browse files
committed
Add API error model, auth, and rate limiting
1 parent d057810 commit 2b1473c

File tree

3 files changed

+200
-21
lines changed

3 files changed

+200
-21
lines changed

api/httpapi/httpapi.go

Lines changed: 160 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@ package httpapi
22

33
import (
44
"encoding/json"
5+
"net"
56
"net/http"
67
"strconv"
8+
"strings"
9+
"sync"
10+
"time"
711

812
wsadapter "gamifykit/adapters/websocket"
913
"gamifykit/core"
@@ -17,6 +21,14 @@ type Options struct {
1721
PathPrefix string
1822
// AllowCORSOrigin, if non-empty, enables basic CORS with the given origin (use "*" for any).
1923
AllowCORSOrigin string
24+
// APIKeys, if non-empty, enables static API key auth via Authorization: Bearer or X-API-Key.
25+
APIKeys []string
26+
// RateLimitEnabled toggles rate limiting.
27+
RateLimitEnabled bool
28+
// RateLimitRPM is the allowed requests per minute per client key.
29+
RateLimitRPM int
30+
// RateLimitBurst defines burst capacity.
31+
RateLimitBurst int
2032
}
2133

2234
// NewMux builds an http.Handler exposing a minimal Gamify REST API and WebSocket stream.
@@ -42,49 +54,74 @@ func NewMux(svc *engine.GamifyService, hub *realtime.Hub, opts Options) http.Han
4254
// Users API
4355
mux.HandleFunc(withPrefix(opts.PathPrefix, "/users/"), func(w http.ResponseWriter, r *http.Request) {
4456
if r.Method != http.MethodGet && r.Method != http.MethodPost {
45-
http.NotFound(w, r)
57+
writeError(w, http.StatusNotFound, "not_found", "route not found", nil)
4658
return
4759
}
4860
parts := split(r.URL.Path, '/')
4961
if len(parts) < 2 {
50-
http.NotFound(w, r)
62+
writeError(w, http.StatusNotFound, "not_found", "route not found", nil)
63+
return
64+
}
65+
user, err := core.NormalizeUserID(core.UserID(parts[1]))
66+
if err != nil {
67+
writeError(w, http.StatusBadRequest, "invalid_user", err.Error(), nil)
5168
return
5269
}
53-
user := core.UserID(parts[1])
5470
switch r.Method {
5571
case http.MethodPost:
5672
if len(parts) >= 3 && parts[2] == "points" {
5773
metric := core.Metric(r.URL.Query().Get("metric"))
5874
if metric == "" {
5975
metric = core.MetricXP
6076
}
61-
delta, _ := strconv.ParseInt(r.URL.Query().Get("delta"), 10, 64)
77+
delta, err := strconv.ParseInt(r.URL.Query().Get("delta"), 10, 64)
78+
if err != nil {
79+
writeError(w, http.StatusBadRequest, "invalid_delta", "delta must be an integer", nil)
80+
return
81+
}
6282
total, err := svc.AddPoints(r.Context(), user, metric, delta)
63-
writeJSON(w, map[string]any{"total": total, "err": errString(err)})
83+
if err != nil {
84+
writeError(w, http.StatusBadRequest, "invalid_input", err.Error(), nil)
85+
return
86+
}
87+
writeJSON(w, map[string]any{"total": total})
6488
return
6589
}
6690
if len(parts) >= 4 && parts[2] == "badges" {
6791
badge := core.Badge(parts[3])
68-
err := svc.AwardBadge(r.Context(), user, badge)
69-
writeJSON(w, map[string]any{"ok": err == nil, "err": errString(err)})
92+
if err := core.ValidateBadgeID(badge); err != nil {
93+
writeError(w, http.StatusBadRequest, "invalid_badge", err.Error(), nil)
94+
return
95+
}
96+
if err := svc.AwardBadge(r.Context(), user, badge); err != nil {
97+
writeError(w, http.StatusBadRequest, "invalid_input", err.Error(), nil)
98+
return
99+
}
100+
writeJSON(w, map[string]any{"ok": true})
70101
return
71102
}
72103
case http.MethodGet:
73104
st, err := svc.GetState(r.Context(), user)
74105
if err != nil {
75-
http.Error(w, err.Error(), http.StatusInternalServerError)
106+
writeError(w, http.StatusInternalServerError, "internal", err.Error(), nil)
76107
return
77108
}
78109
writeJSON(w, st)
79110
return
80111
}
81-
http.NotFound(w, r)
112+
writeError(w, http.StatusNotFound, "not_found", "route not found", nil)
82113
})
83114

84115
var handler http.Handler = mux
85116
if opts.AllowCORSOrigin != "" {
86117
handler = withCORS(handler, opts.AllowCORSOrigin)
87118
}
119+
if len(opts.APIKeys) > 0 {
120+
handler = withAPIKeyAuth(handler, opts.APIKeys)
121+
}
122+
if opts.RateLimitEnabled && opts.RateLimitRPM > 0 && opts.RateLimitBurst > 0 {
123+
handler = withRateLimit(handler, opts.RateLimitRPM, opts.RateLimitBurst)
124+
}
88125
return handler
89126
}
90127

@@ -155,11 +192,16 @@ func writeJSON(w http.ResponseWriter, v any) {
155192
_ = json.NewEncoder(w).Encode(v)
156193
}
157194

158-
func errString(err error) any {
159-
if err == nil {
160-
return nil
161-
}
162-
return err.Error()
195+
type apiError struct {
196+
Code string `json:"code"`
197+
Message string `json:"message"`
198+
Details any `json:"details,omitempty"`
199+
}
200+
201+
func writeError(w http.ResponseWriter, status int, code, msg string, details any) {
202+
w.Header().Set("Content-Type", "application/json")
203+
w.WriteHeader(status)
204+
_ = json.NewEncoder(w).Encode(apiError{Code: code, Message: msg, Details: details})
163205
}
164206

165207
// withCORS wraps a handler with a minimal CORS policy.
@@ -176,3 +218,107 @@ func withCORS(next http.Handler, origin string) http.Handler {
176218
next.ServeHTTP(w, r)
177219
})
178220
}
221+
222+
// withAPIKeyAuth enforces a shared API key list.
223+
func withAPIKeyAuth(next http.Handler, apiKeys []string) http.Handler {
224+
allowed := make(map[string]struct{}, len(apiKeys))
225+
for _, k := range apiKeys {
226+
k = strings.TrimSpace(k)
227+
if k != "" {
228+
allowed[k] = struct{}{}
229+
}
230+
}
231+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
232+
key := extractAPIKey(r)
233+
if key == "" {
234+
writeError(w, http.StatusUnauthorized, "unauthorized", "missing API key", nil)
235+
return
236+
}
237+
if _, ok := allowed[key]; !ok {
238+
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key", nil)
239+
return
240+
}
241+
next.ServeHTTP(w, r)
242+
})
243+
}
244+
245+
// withRateLimit applies a simple token-bucket limiter per client key.
246+
func withRateLimit(next http.Handler, rpm int, burst int) http.Handler {
247+
limiter := newRateLimiter(rpm, burst)
248+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
249+
key := clientKey(r)
250+
if !limiter.allow(key) {
251+
writeError(w, http.StatusTooManyRequests, "rate_limited", "too many requests", nil)
252+
return
253+
}
254+
next.ServeHTTP(w, r)
255+
})
256+
}
257+
258+
func extractAPIKey(r *http.Request) string {
259+
auth := r.Header.Get("Authorization")
260+
if strings.HasPrefix(strings.ToLower(auth), "bearer ") {
261+
return strings.TrimSpace(auth[7:])
262+
}
263+
if key := r.Header.Get("X-API-Key"); key != "" {
264+
return key
265+
}
266+
return ""
267+
}
268+
269+
// clientKey uses API key if present, otherwise remote IP.
270+
func clientKey(r *http.Request) string {
271+
if key := extractAPIKey(r); key != "" {
272+
return key
273+
}
274+
host, _, err := net.SplitHostPort(r.RemoteAddr)
275+
if err != nil {
276+
return r.RemoteAddr
277+
}
278+
return host
279+
}
280+
281+
type rateLimiter struct {
282+
rpm float64
283+
burst float64
284+
mu sync.Mutex
285+
b map[string]*bucket
286+
}
287+
288+
type bucket struct {
289+
tokens float64
290+
last time.Time
291+
}
292+
293+
func newRateLimiter(rpm, burst int) *rateLimiter {
294+
return &rateLimiter{
295+
rpm: float64(rpm),
296+
burst: float64(burst),
297+
b: make(map[string]*bucket),
298+
}
299+
}
300+
301+
func (l *rateLimiter) allow(key string) bool {
302+
now := time.Now()
303+
l.mu.Lock()
304+
defer l.mu.Unlock()
305+
306+
b, ok := l.b[key]
307+
if !ok {
308+
l.b[key] = &bucket{tokens: l.burst - 1, last: now}
309+
return true
310+
}
311+
312+
elapsed := now.Sub(b.last).Minutes()
313+
b.tokens += elapsed * l.rpm
314+
if b.tokens > l.burst {
315+
b.tokens = l.burst
316+
}
317+
if b.tokens < 1 {
318+
b.last = now
319+
return false
320+
}
321+
b.tokens--
322+
b.last = now
323+
return true
324+
}

cmd/gamifykit-server/app.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,12 @@ func provideService(hub *realtime.Hub, storage engine.Storage) *engine.GamifySer
6262

6363
func provideHandler(svc *engine.GamifyService, hub *realtime.Hub, cfg *config.Config) http.Handler {
6464
return httpapi.NewMux(svc, hub, httpapi.Options{
65-
PathPrefix: cfg.Server.PathPrefix,
66-
AllowCORSOrigin: cfg.Server.CORSOrigin,
65+
PathPrefix: cfg.Server.PathPrefix,
66+
AllowCORSOrigin: cfg.Server.CORSOrigin,
67+
APIKeys: cfg.Security.APIKeys,
68+
RateLimitEnabled: cfg.Security.EnableRateLimit,
69+
RateLimitRPM: cfg.Security.RateLimit.RequestsPerMinute,
70+
RateLimitBurst: cfg.Security.RateLimit.BurstSize,
6771
})
6872
}
6973

config/config.go

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ type ServerConfig struct {
6060

6161
// StorageConfig holds storage adapter configuration
6262
type StorageConfig struct {
63-
Adapter string `json:"adapter" env:"GAMIFYKIT_STORAGE_ADAPTER"`
64-
Redis redis.Config `json:"redis,omitempty"`
65-
SQL sqlx.Config `json:"sql,omitempty"`
66-
File FileConfig `json:"file,omitempty"`
63+
Adapter string `json:"adapter" env:"GAMIFYKIT_STORAGE_ADAPTER"`
64+
Redis redis.Config `json:"redis,omitempty"`
65+
SQL sqlx.Config `json:"sql,omitempty"`
66+
File FileConfig `json:"file,omitempty"`
6767
}
6868

6969
// FileConfig holds JSON file storage configuration
@@ -89,8 +89,9 @@ type MetricsConfig struct {
8989

9090
// SecurityConfig holds security-related configuration
9191
type SecurityConfig struct {
92-
EnableRateLimit bool `json:"enable_rate_limit" env:"GAMIFYKIT_SECURITY_RATE_LIMIT_ENABLED"`
92+
EnableRateLimit bool `json:"enable_rate_limit" env:"GAMIFYKIT_SECURITY_RATE_LIMIT_ENABLED"`
9393
RateLimit RateLimitConfig `json:"rate_limit,omitempty"`
94+
APIKeys []string `json:"api_keys,omitempty" env:"GAMIFYKIT_SECURITY_API_KEYS"`
9495
}
9596

9697
// RateLimitConfig holds rate limiting configuration
@@ -100,6 +101,28 @@ type RateLimitConfig struct {
100101
CleanupInterval time.Duration `json:"cleanup_interval" env:"GAMIFYKIT_SECURITY_RATE_LIMIT_CLEANUP"`
101102
}
102103

104+
// Validate validates security settings.
105+
func (s SecurityConfig) Validate() error {
106+
var errs []string
107+
if s.EnableRateLimit {
108+
if s.RateLimit.RequestsPerMinute <= 0 {
109+
errs = append(errs, "rate_limit.requests_per_minute must be > 0 when rate limiting is enabled")
110+
}
111+
if s.RateLimit.BurstSize <= 0 {
112+
errs = append(errs, "rate_limit.burst_size must be > 0 when rate limiting is enabled")
113+
}
114+
}
115+
for i, key := range s.APIKeys {
116+
if strings.TrimSpace(key) == "" {
117+
errs = append(errs, fmt.Sprintf("api_keys[%d] is empty", i))
118+
}
119+
}
120+
if len(errs) > 0 {
121+
return errors.New(strings.Join(errs, "; "))
122+
}
123+
return nil
124+
}
125+
103126
// Load loads configuration from environment variables and validates it
104127
func Load() (*Config, error) {
105128
cfg := DefaultConfig()
@@ -214,6 +237,7 @@ func DefaultConfig() *Config {
214237
BurstSize: 10,
215238
CleanupInterval: 5 * time.Minute,
216239
},
240+
APIKeys: []string{},
217241
},
218242
}
219243
}
@@ -247,6 +271,11 @@ func (c *Config) Validate() error {
247271
errs = append(errs, fmt.Sprintf("metrics config: %v", err))
248272
}
249273

274+
// Validate security config
275+
if err := c.Security.Validate(); err != nil {
276+
errs = append(errs, fmt.Sprintf("security config: %v", err))
277+
}
278+
250279
if len(errs) > 0 {
251280
return errors.New(strings.Join(errs, "; "))
252281
}

0 commit comments

Comments
 (0)