Skip to content
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ import "github.com/hertz-contrib/cache"
package main

import (
"cache"
"cache/persist"
"context"
"fmt"
"net/http"
Expand All @@ -45,6 +43,8 @@ import (

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/hertz-contrib/cache"
"github.com/hertz-contrib/cache/persist"
)

func main() {
Expand Down Expand Up @@ -84,15 +84,15 @@ func main() {
package main

import (
"cache"
"cache/persist"
"context"
"net/http"
"time"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/go-redis/redis/v8"
"github.com/hertz-contrib/cache"
"github.com/hertz-contrib/cache/persist"
)

func main() {
Expand Down
8 changes: 4 additions & 4 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ import "github.com/hertz-contrib/cache"
package main

import (
"cache"
"cache/persist"
"context"
"fmt"
"net/http"
Expand All @@ -44,6 +42,8 @@ import (

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/hertz-contrib/cache"
"github.com/hertz-contrib/cache/persist"
)

func main() {
Expand Down Expand Up @@ -83,15 +83,15 @@ func main() {
package main

import (
"cache"
"cache/persist"
"context"
"net/http"
"time"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/go-redis/redis/v8"
"github.com/hertz-contrib/cache"
"github.com/hertz-contrib/cache/persist"
)

func main() {
Expand Down
116 changes: 75 additions & 41 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
package cache

import (
"cache/persist"
"context"
"encoding/gob"
"errors"
Expand All @@ -55,6 +54,7 @@ import (
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/hertz-contrib/cache/persist"
"golang.org/x/sync/singleflight"
)

Expand All @@ -77,9 +77,11 @@ type GetCacheStrategyByRequest func(ctx context.Context, c *app.RequestContext)
const (
errMissingCacheStrategy = "[CACHE] cache strategy is nil"
getCacheErrorFormat = "[CACHE] get cache error: %s, cache key: %s"
setCacheKeyErrorFormat = "[CACHE] set cache key error"
setCacheKeyErrorFormat = "[CACHE] set cache key error: %s, cache key: %s"
getRequestUriIgnoreQueryOrderErrorFormat = "[CACHE] getRequestUriIgnoreQueryOrder error: %s"
writeResponseErrorFormat = "[CACHE] write response error: %s"
singleFlightErrorFormat = "[CACHE] call the function in-flight error: %s"
fallbackCacheKeyFormat = "[CACHE] Fallback to default cache key: %s"
)

// NewCache user must pass getCacheKey to describe the way to generate cache key
Expand Down Expand Up @@ -151,7 +153,7 @@ func newCache(
}

inFlight := false
rawRespCache, _, _ := sfGroup.Do(cacheKey, func() (interface{}, error) {
rawRespCache, err, _ := sfGroup.Do(cacheKey, func() (interface{}, error) {
if options.singleFlightForgetTimeout > 0 {
forgetTimer := time.AfterFunc(options.singleFlightForgetTimeout, func() {
sfGroup.Forget(cacheKey)
Expand All @@ -176,43 +178,83 @@ func newCache(
return respCache, nil
})

if err != nil {
hlog.CtxErrorf(ctx, singleFlightErrorFormat, err)
}

if !inFlight {
replyWithCache(ctx, c, options, rawRespCache.(*ResponseCache))
options.shareSingleFlightCallback(ctx, c)
}
}
}

// NewCacheByRequestURI a shortcut function for caching response by uri
func NewCacheByRequestURI(defaultCacheStore persist.CacheStore, defaultExpire time.Duration, opts ...Option) app.HandlerFunc {
options := newOptions(opts...)
// KeyStrategy defines the interface for cache key generation strategies.
type KeyStrategy interface {
GenerateKey(c *app.RequestContext) (string, error)
}

var cacheStrategy GetCacheStrategyByRequest
if options.ignoreQueryOrder {
cacheStrategy = func(ctx context.Context, c *app.RequestContext) (bool, Strategy) {
newUri, err := getRequestUriIgnoreQueryOrder(c.Request.URI().String())
if err != nil {
hlog.CtxErrorf(ctx, getRequestUriIgnoreQueryOrderErrorFormat, err)
newUri = c.Request.URI().String()
}
// ByURI implements KeyStrategy using the request URI.
type ByURI struct{}

return true, Strategy{
CacheKey: newUri,
}
func (s *ByURI) GenerateKey(c *app.RequestContext) (string, error) {
return string(c.Request.RequestURI()), nil
}

// ByURIWithIgnoreQueryOrder implements KeyStrategy using the request URI with ordered query parameters.
type ByURIWithIgnoreQueryOrder struct{}

func (s *ByURIWithIgnoreQueryOrder) GenerateKey(c *app.RequestContext) (string, error) {
return getRequestUriIgnoreQueryOrder(string(c.Request.RequestURI()))
}

// ByPath implements KeyStrategy using the request path.
type ByPath struct{}

func (s *ByPath) GenerateKey(c *app.RequestContext) (string, error) {
return b2s(c.Request.Path()), nil
}

// NewCacheByKeyStrategy is a shortcut function for caching responses based on configurable key generation strategies.
func NewCacheByKeyStrategy(defaultCacheStore persist.CacheStore, defaultExpire time.Duration, strategy KeyStrategy, opts ...Option) app.HandlerFunc {
cacheStrategy := func(ctx context.Context, c *app.RequestContext) (bool, Strategy) {
cacheKey, err := strategy.GenerateKey(c)
if err != nil {
hlog.CtxErrorf(ctx, getRequestUriIgnoreQueryOrderErrorFormat, err)
cacheKey = string(c.Request.RequestURI())
hlog.CtxErrorf(ctx, fallbackCacheKeyFormat, err)
}
} else {
cacheStrategy = func(ctx context.Context, c *app.RequestContext) (bool, Strategy) {
return true, Strategy{
CacheKey: c.Request.URI().String(),
}
return true, Strategy{
CacheKey: cacheKey,
}
}

options.getCacheStrategyByRequest = cacheStrategy
var options []Option
options = append(options, WithCacheStrategyByRequest(cacheStrategy))
options = append(options, opts...)

return newCache(defaultCacheStore, defaultExpire, options)
return NewCache(defaultCacheStore, defaultExpire, options...)
}

// NewCacheByRequestURI a shortcut function for caching response by uri.
func NewCacheByRequestURI(store persist.CacheStore, duration time.Duration, opts ...Option) app.HandlerFunc {
strategy := &ByURI{}
return NewCacheByKeyStrategy(store, duration, strategy, opts...)
}

// NewCacheByRequestURIWithIgnoreQueryOrder a shortcut function for caching response by uri and ignore query param order.
func NewCacheByRequestURIWithIgnoreQueryOrder(store persist.CacheStore, duration time.Duration, opts ...Option) app.HandlerFunc {
strategy := &ByURIWithIgnoreQueryOrder{}
return NewCacheByKeyStrategy(store, duration, strategy, opts...)
}

// NewCacheByRequestPath a shortcut function for caching response by url path, means will discard the query params.
func NewCacheByRequestPath(store persist.CacheStore, duration time.Duration, opts ...Option) app.HandlerFunc {
strategy := &ByPath{}
return NewCacheByKeyStrategy(store, duration, strategy, opts...)
}

// getRequestUriIgnoreQueryOrder returns a URI with query parameters sorted alphabetically by key and value.
func getRequestUriIgnoreQueryOrder(requestURI string) (string, error) {
parsedUrl, err := url.ParseRequestURI(requestURI)
if err != nil {
Expand Down Expand Up @@ -242,17 +284,6 @@ func getRequestUriIgnoreQueryOrder(requestURI string) (string, error) {
return parsedUrl.Path + "?" + strings.Join(queryVals, "&"), nil
}

// NewCacheByRequestPath a shortcut function for caching response by url path, means will discard the query params
func NewCacheByRequestPath(defaultCacheStore persist.CacheStore, defaultExpire time.Duration, opts ...Option) app.HandlerFunc {
opts = append(opts, WithCacheStrategyByRequest(func(ctx context.Context, c *app.RequestContext) (bool, Strategy) {
return true, Strategy{
CacheKey: b2s(c.Request.Path()),
}
}))

return NewCache(defaultCacheStore, defaultExpire, opts...)
}

func init() {
gob.Register(&ResponseCache{})
}
Expand All @@ -266,16 +297,19 @@ type ResponseCache struct {

func (c *ResponseCache) fillWithCacheWriter(cacheWriter *responseCacheWriter, withoutHeader bool) {
c.Status = cacheWriter.StatusCode()
c.Data = cacheWriter.Body()
body := cacheWriter.Body()
buf := make([]byte, len(body))
copy(buf, body)
c.Data = buf
if !withoutHeader {
c.Header = make(map[string][]string)
for _, val := range cacheWriter.Header.GetHeaders() {
if c.Header.Values(b2s(val.GetKey())) != nil {
c.Header.Add(b2s(val.GetKey()), b2s(val.GetValue()))
cacheWriter.Header.VisitAll(func(key, value []byte) {
if c.Header.Get(b2s(key)) != "" {
c.Header.Add(b2s(key), b2s(value))
} else {
c.Header.Set(b2s(val.GetKey()), b2s(val.GetValue()))
c.Header.Set(b2s(key), b2s(value))
}
}
})
}
}

Expand Down
70 changes: 69 additions & 1 deletion cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
package cache

import (
"cache/persist"
"context"
"fmt"
"io"
"math/rand"
"net/http"
"sync"
Expand All @@ -52,10 +52,12 @@ import (
"time"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/common/config"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/ut"
"github.com/cloudwego/hertz/pkg/route"
"github.com/hertz-contrib/cache/persist"
)

func hertzHandler(middleware app.HandlerFunc, withRand bool) *route.Engine {
Expand Down Expand Up @@ -245,3 +247,69 @@ func TestPrefixKey(t *testing.T) {
w2 := ut.PerformRequest(handler, "GET", requestPath, nil)
assert.NotEqual(t, w1.Body, w2.Body)
}

func TestNewCache_Memory(t *testing.T) {
h := server.New(
server.WithHostPorts("127.0.0.1:9233"))
original := map[string][]byte{
"/tmp-cache/ping1": []byte("{\"data\":{\"num\":1111111111}}"),
"/tmp-cache/ping2": []byte("{\"data\":{\"num\":2222222222222222222}}"),
"/tmp-cache/ping3": []byte("{\"data\":{\"num\":3333333333333333333333333333}}"),
}
h.Use(NewCache(persist.NewMemoryStore(time.Second), 3*time.Second,
WithCacheStrategyByRequest(func(ctx context.Context, c *app.RequestContext) (bool, Strategy) {
return true, Strategy{
CacheKey: c.Request.URI().String(),
CacheDuration: 5 * time.Second,
}
})))
h.GET("/tmp-cache/*path", func(ctx context.Context, c *app.RequestContext) {
if data, ok := original[string(c.Request.Path())]; ok {
_, _ = c.Response.BodyWriter().Write(data)
return
}
})
go h.Spin()

tests := []struct {
want []byte
url string
}{
{
want: original["/tmp-cache/ping1"],
url: "http://127.0.0.1:9233/tmp-cache/ping1",
},
{
want: original["/tmp-cache/ping2"],
url: "http://127.0.0.1:9233/tmp-cache/ping2",
},
{
want: original["/tmp-cache/ping3"],
url: "http://127.0.0.1:9233/tmp-cache/ping3",
},
}

for i := 0; i < 10; i++ {
for _, tt := range tests {
t.Run("cache data", func(t *testing.T) {
resp, err := http.Get(tt.url)
assert.Nil(t, err)
body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
got := body
assert.DeepEqual(t, string(tt.want), string(got))
})
}
}
}

func TestCacheByURIWithIgnoreQueryOrder(t *testing.T) {
memoryStore := persist.NewMemoryStore(1 * time.Minute)
cacheURIMiddleware := NewCacheByRequestURIWithIgnoreQueryOrder(memoryStore, 3*time.Second)
handler := hertzHandler(cacheURIMiddleware, false)

w1 := ut.PerformRequest(handler, "GET", "/cache?uid=u1&b=2&a=1", nil)
w2 := ut.PerformRequest(handler, "GET", "/cache?a=1&uid=u1&b=2", nil)

assert.DeepEqual(t, w1.Body, w2.Body)
}
4 changes: 2 additions & 2 deletions example/memory/memory_example.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@
package main

import (
"cache"
"cache/persist"
"context"
"net/http"
"time"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/hertz-contrib/cache"
"github.com/hertz-contrib/cache/persist"
)

func main() {
Expand Down
4 changes: 2 additions & 2 deletions example/options/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@
package main

import (
"cache"
"cache/persist"
"context"
"fmt"
"net/http"
Expand All @@ -51,6 +49,8 @@ import (

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/hertz-contrib/cache"
"github.com/hertz-contrib/cache/persist"
)

func main() {
Expand Down
Loading