diff --git a/app/cmd/cachemanager/flush.go b/app/cmd/cachemanager/flush.go index 8d57617..bd6f267 100644 --- a/app/cmd/cachemanager/flush.go +++ b/app/cmd/cachemanager/flush.go @@ -2,21 +2,29 @@ package cachemanager import ( "context" + "errors" "github.com/fthvgb1/wp-go/cache" + str "github.com/fthvgb1/wp-go/helper/strings" + "github.com/fthvgb1/wp-go/safety" "time" ) var ctx = context.Background() +var mapFlush = safety.NewMap[string, func(any)]() +var getSingleFn = safety.NewMap[string, func(context.Context, any, time.Duration, ...any) (any, error)]() +var getBatchFn = safety.NewMap[string, func(context.Context, any, time.Duration, ...any) (any, error)]() +var anyFlush = safety.NewMap[string, func()]() + type flush interface { Flush(ctx context.Context) } -type clear interface { +type clearExpired interface { ClearExpired(ctx context.Context) } -var clears []clear +var clears []clearExpired var flushes []flush @@ -26,23 +34,108 @@ func Flush() { } } -func MapCacheBy[K comparable, V any](fn func(...any) (V, error), expireTime time.Duration) *cache.MapCache[K, V] { - m := cache.NewMemoryMapCacheByFn[K, V](fn, expireTime) +func FlushMapVal[T any](name string, key T) { + v, ok := mapFlush.Load(name) + if !ok { + return + } + v(key) +} + +func FlushAnyVal(name ...string) { + for _, s := range name { + v, ok := anyFlush.Load(s) + if ok { + v() + } + } +} + +func pushFlushMap[K comparable, V any](m *cache.MapCache[K, V], args ...any) { + name := parseArgs(args...) + if name != "" { + anyFlush.Store(name, func() { + m.Flush(ctx) + }) + mapFlush.Store(name, func(a any) { + k, ok := a.(K) + if ok { + m.Delete(ctx, k) + } + }) + getSingleFn.Store(name, func(ct context.Context, k any, t time.Duration, a ...any) (any, error) { + kk, ok := k.(K) + if !ok { + return nil, errors.New(str.Join("cache ", name, " key type err")) + } + return m.GetCache(ct, kk, t, a...) + }) + getBatchFn.Store(name, func(ct context.Context, k any, t time.Duration, a ...any) (any, error) { + kk, ok := k.([]K) + if !ok { + return nil, errors.New(str.Join("cache ", name, " key type err")) + } + return m.GetCacheBatch(ct, kk, t, a...) + }) + FlushPush() + } +} + +func Get[T, K any](name string, ct context.Context, key K, timeout time.Duration, params ...any) (r T, err error) { + v, ok := getSingleFn.Load(name) + if !ok { + err = errors.New(str.Join("cache ", name, " doesn't exist")) + return + } + vv, err := v(ct, key, timeout, params...) + if err != nil { + return r, err + } + r = vv.(T) + return +} +func GetMultiple[T, K any](name string, ct context.Context, key []K, timeout time.Duration, params ...any) (r []T, err error) { + v, ok := getBatchFn.Load(name) + if !ok { + err = errors.New(str.Join("cache ", name, " doesn't exist")) + return + } + vv, err := v(ct, key, timeout, params...) + if err != nil { + return r, err + } + r = vv.([]T) + return +} + +func parseArgs(args ...any) string { + var name string + for _, arg := range args { + v, ok := arg.(string) + if ok { + name = v + } + } + return name +} + +func NewMapCache[K comparable, V any](data cache.Cache[K, V], batchFn cache.MapBatchFn[K, V], + fn cache.MapSingleFn[K, V], expireTime time.Duration, args ...any) *cache.MapCache[K, V] { + m := cache.NewMapCache[K, V](data, fn, batchFn, expireTime) + pushFlushMap(m, args...) FlushPush(m) ClearPush(m) return m } -func MapBatchCacheBy[K comparable, V any](fn func(...any) (map[K]V, error), expireTime time.Duration) *cache.MapCache[K, V] { - m := cache.NewMemoryMapCacheByBatchFn[K, V](fn, expireTime) - FlushPush(m) - ClearPush(m) - return m +func NewMemoryMapCache[K comparable, V any](batchFn cache.MapBatchFn[K, V], + fn cache.MapSingleFn[K, V], expireTime time.Duration, args ...any) *cache.MapCache[K, V] { + return NewMapCache[K, V](cache.NewMemoryMapCache[K, V](), batchFn, fn, expireTime, args...) } func FlushPush(f ...flush) { flushes = append(flushes, f...) } -func ClearPush(c ...clear) { +func ClearPush(c ...clearExpired) { clears = append(clears, c...) } diff --git a/app/pkg/cache/cache.go b/app/pkg/cache/cache.go index cf37edf..f8aa08e 100644 --- a/app/pkg/cache/cache.go +++ b/app/pkg/cache/cache.go @@ -27,10 +27,7 @@ var postListIdsCache *cache.MapCache[string, dao.PostIds] var searchPostIdsCache *cache.MapCache[string, dao.PostIds] var maxPostIdCache *cache.VarCache[uint64] -var usersCache *cache.MapCache[uint64, models.Users] var usersNameCache *cache.MapCache[string, models.Users] -var commentsCache *cache.MapCache[uint64, models.Comments] - var feedCache *cache.VarCache[[]string] var postFeedCache *cache.MapCache[string, string] @@ -44,43 +41,43 @@ var allUsernameCache *cache.VarCache[map[string]struct{}] func InitActionsCommonCache() { c := config.GetConfig() - searchPostIdsCache = cachemanager.MapCacheBy[string](dao.SearchPostIds, c.CacheTime.SearchPostCacheTime) + searchPostIdsCache = cachemanager.NewMemoryMapCache(nil, dao.SearchPostIds, c.CacheTime.SearchPostCacheTime, "searchPostIds") - postListIdsCache = cachemanager.MapCacheBy[string](dao.SearchPostIds, c.CacheTime.PostListCacheTime) + postListIdsCache = cachemanager.NewMemoryMapCache(nil, dao.SearchPostIds, c.CacheTime.PostListCacheTime, "listPostIds") - monthPostsCache = cachemanager.MapCacheBy[string](dao.MonthPost, c.CacheTime.MonthPostCacheTime) + monthPostsCache = cachemanager.NewMemoryMapCache(nil, dao.MonthPost, c.CacheTime.MonthPostCacheTime, "monthPostIds") - postContextCache = cachemanager.MapCacheBy[uint64](dao.GetPostContext, c.CacheTime.ContextPostCacheTime) + postContextCache = cachemanager.NewMemoryMapCache(nil, dao.GetPostContext, c.CacheTime.ContextPostCacheTime, "postContext") - postsCache = cachemanager.MapBatchCacheBy(dao.GetPostsByIds, c.CacheTime.PostDataCacheTime) + postsCache = cachemanager.NewMemoryMapCache(dao.GetPostsByIds, nil, c.CacheTime.PostDataCacheTime, "postData") - postMetaCache = cachemanager.MapBatchCacheBy(dao.GetPostMetaByPostIds, c.CacheTime.PostDataCacheTime) + postMetaCache = cachemanager.NewMemoryMapCache(dao.GetPostMetaByPostIds, nil, c.CacheTime.PostDataCacheTime, "postMetaData") - categoryAndTagsCaches = cachemanager.MapCacheBy[string](dao.CategoriesAndTags, c.CacheTime.CategoryCacheTime) + categoryAndTagsCaches = cachemanager.NewMemoryMapCache(nil, dao.CategoriesAndTags, c.CacheTime.CategoryCacheTime, "categoryAndTagsData") recentPostsCaches = cache.NewVarCache(dao.RecentPosts, c.CacheTime.RecentPostCacheTime) recentCommentsCaches = cache.NewVarCache(dao.RecentComments, c.CacheTime.RecentCommentsCacheTime) - postCommentCaches = cachemanager.MapCacheBy[uint64](dao.PostComments, c.CacheTime.PostCommentsCacheTime) + postCommentCaches = cachemanager.NewMemoryMapCache(nil, dao.PostComments, c.CacheTime.PostCommentsCacheTime, "postCommentIds") maxPostIdCache = cache.NewVarCache(dao.GetMaxPostId, c.CacheTime.MaxPostIdCacheTime) - usersCache = cachemanager.MapCacheBy[uint64](dao.GetUserById, c.CacheTime.UserInfoCacheTime) + cachemanager.NewMemoryMapCache(nil, dao.GetUserById, c.CacheTime.UserInfoCacheTime, "userData") - usersNameCache = cachemanager.MapCacheBy[string](dao.GetUserByName, c.CacheTime.UserInfoCacheTime) + usersNameCache = cachemanager.NewMemoryMapCache(nil, dao.GetUserByName, c.CacheTime.UserInfoCacheTime, "usernameMapToUserData") - commentsCache = cachemanager.MapBatchCacheBy(dao.GetCommentByIds, c.CacheTime.CommentsCacheTime) + cachemanager.NewMemoryMapCache(dao.GetCommentByIds, nil, c.CacheTime.CommentsCacheTime, "commentData") allUsernameCache = cache.NewVarCache(dao.AllUsername, c.CacheTime.UserInfoCacheTime) feedCache = cache.NewVarCache(feed, time.Hour) - postFeedCache = cachemanager.MapCacheBy[string](postFeed, time.Hour) + postFeedCache = cachemanager.NewMemoryMapCache(nil, postFeed, time.Hour, "postFeed") commentsFeedCache = cache.NewVarCache(commentsFeed, time.Hour) - newCommentCache = cachemanager.MapCacheBy[string, string](nil, 15*time.Minute) + newCommentCache = cachemanager.NewMemoryMapCache[string, string](nil, nil, 15*time.Minute, "feed-NewComment") InitFeed() } diff --git a/app/pkg/cache/comments.go b/app/pkg/cache/comments.go index 4d832d1..6d002ba 100644 --- a/app/pkg/cache/comments.go +++ b/app/pkg/cache/comments.go @@ -2,6 +2,7 @@ package cache import ( "context" + "github.com/fthvgb1/wp-go/app/cmd/cachemanager" "github.com/fthvgb1/wp-go/app/pkg/logs" "github.com/fthvgb1/wp-go/app/pkg/models" "github.com/fthvgb1/wp-go/cache" @@ -28,11 +29,11 @@ func PostComments(ctx context.Context, Id uint64) ([]models.Comments, error) { } func GetCommentById(ctx context.Context, id uint64) (models.Comments, error) { - return commentsCache.GetCache(ctx, id, time.Second, ctx, id) + return cachemanager.Get[models.Comments]("commentData", ctx, id, time.Second) } func GetCommentByIds(ctx context.Context, ids []uint64) ([]models.Comments, error) { - return commentsCache.GetCacheBatch(ctx, ids, time.Second, ctx, ids) + return cachemanager.GetMultiple[models.Comments]("commentData", ctx, ids, time.Second) } func NewCommentCache() *cache.MapCache[string, string] { diff --git a/app/pkg/cache/feed.go b/app/pkg/cache/feed.go index 6c742e8..c494b17 100644 --- a/app/pkg/cache/feed.go +++ b/app/pkg/cache/feed.go @@ -1,6 +1,7 @@ package cache import ( + "context" "fmt" "github.com/fthvgb1/wp-go/app/pkg/logs" "github.com/fthvgb1/wp-go/app/pkg/models" @@ -92,9 +93,8 @@ func feed(arg ...any) (xml []string, err error) { return } -func postFeed(arg ...any) (x string, err error) { - c := arg[0].(*gin.Context) - id := arg[1].(string) +func postFeed(c context.Context, id string, arg ...any) (x string, err error) { + id = arg[1].(string) ID := str.ToInteger[uint64](id, 0) maxId, err := GetMaxPostId(c) logs.IfError(err, "get max post id") diff --git a/app/pkg/cache/posts.go b/app/pkg/cache/posts.go index 654acb9..9e2177a 100644 --- a/app/pkg/cache/posts.go +++ b/app/pkg/cache/posts.go @@ -9,16 +9,17 @@ import ( "github.com/fthvgb1/wp-go/helper/number" "github.com/fthvgb1/wp-go/helper/slice" str "github.com/fthvgb1/wp-go/helper/strings" - "github.com/gin-gonic/gin" "time" ) func GetPostById(ctx context.Context, id uint64) (models.Posts, error) { - return postsCache.GetCache(ctx, id, time.Second, ctx, id) + //return cachemanager.Get[models.Posts]("postData", ctx, id, time.Second) + return postsCache.GetCache(ctx, id, time.Second) } func GetPostsByIds(ctx context.Context, ids []uint64) ([]models.Posts, error) { - return postsCache.GetCacheBatch(ctx, ids, time.Second, ctx, ids) + //return cachemanager.GetMultiple[models.Posts]("postData", ctx, ids, time.Second) + return postsCache.GetCacheBatch(ctx, ids, time.Second) } func SearchPost(ctx context.Context, key string, args ...any) (r []models.Posts, total int, err error) { @@ -41,7 +42,7 @@ func PostLists(ctx context.Context, key string, args ...any) (r []models.Posts, return } -func GetMaxPostId(ctx *gin.Context) (uint64, error) { +func GetMaxPostId(ctx context.Context) (uint64, error) { return maxPostIdCache.GetCache(ctx, time.Second, ctx) } diff --git a/app/pkg/cache/users.go b/app/pkg/cache/users.go index 860200f..c546e66 100644 --- a/app/pkg/cache/users.go +++ b/app/pkg/cache/users.go @@ -2,6 +2,7 @@ package cache import ( "context" + "github.com/fthvgb1/wp-go/app/cmd/cachemanager" "github.com/fthvgb1/wp-go/app/pkg/logs" "github.com/fthvgb1/wp-go/app/pkg/models" "github.com/fthvgb1/wp-go/model" @@ -24,7 +25,7 @@ func GetAllUsername(ctx context.Context) (map[string]struct{}, error) { } func GetUserById(ctx context.Context, uid uint64) models.Users { - r, err := usersCache.GetCache(ctx, uid, time.Second, ctx, uid) + r, err := cachemanager.Get[models.Users]("userData", ctx, uid, time.Second) logs.IfError(err, "get user", uid) return r } diff --git a/app/pkg/dao/comments.go b/app/pkg/dao/comments.go index 504e0c5..8b219d4 100644 --- a/app/pkg/dao/comments.go +++ b/app/pkg/dao/comments.go @@ -28,9 +28,7 @@ func RecentComments(a ...any) (r []models.Comments, err error) { // PostComments // param1 context.Context // param2 postId -func PostComments(args ...any) ([]uint64, error) { - ctx := args[0].(context.Context) - postId := args[1].(uint64) +func PostComments(ctx context.Context, postId uint64, _ ...any) ([]uint64, error) { r, err := model.Finds[models.Comments](ctx, model.Conditions( model.Where(model.SqlBuilder{ {"comment_approved", "1"}, @@ -50,9 +48,7 @@ func PostComments(args ...any) ([]uint64, error) { }), err } -func GetCommentByIds(args ...any) (map[uint64]models.Comments, error) { - ctx := args[0].(context.Context) - ids := args[1].([]uint64) +func GetCommentByIds(ctx context.Context, ids []uint64, _ ...any) (map[uint64]models.Comments, error) { m := make(map[uint64]models.Comments) r, err := model.SimpleFind[models.Comments](ctx, model.SqlBuilder{ {"comment_ID", "in", ""}, {"comment_approved", "1"}, diff --git a/app/pkg/dao/common.go b/app/pkg/dao/common.go index 40a34be..04d4925 100644 --- a/app/pkg/dao/common.go +++ b/app/pkg/dao/common.go @@ -21,17 +21,13 @@ type PostContext struct { Next models.Posts } -func CategoriesAndTags(a ...any) (terms []models.TermsMy, err error) { - ctx := a[0].(context.Context) - t, ok := a[1].(string) +func CategoriesAndTags(ctx context.Context, t string, _ ...any) (terms []models.TermsMy, err error) { var in = []any{"category", "post_tag"} - if ok { - switch t { - case constraints.Category: - in = []any{"category"} - case constraints.Tag: - in = []any{"post_tag"} - } + switch t { + case constraints.Category: + in = []any{"category"} + case constraints.Tag: + in = []any{"post_tag"} } w := model.SqlBuilder{ {"tt.taxonomy", "in", ""}, diff --git a/app/pkg/dao/postmeta.go b/app/pkg/dao/postmeta.go index 1e80793..8a3927e 100644 --- a/app/pkg/dao/postmeta.go +++ b/app/pkg/dao/postmeta.go @@ -11,10 +11,8 @@ import ( "strconv" ) -func GetPostMetaByPostIds(args ...any) (r map[uint64]map[string]any, err error) { +func GetPostMetaByPostIds(ctx context.Context, ids []uint64, _ ...any) (r map[uint64]map[string]any, err error) { r = make(map[uint64]map[string]any) - ctx := args[0].(context.Context) - ids := args[1].([]uint64) rr, err := model.Finds[models.PostMeta](ctx, model.Conditions( model.Where(model.SqlBuilder{{"post_id", "in", ""}}), model.In(slice.ToAnySlice(ids)), diff --git a/app/pkg/dao/posts.go b/app/pkg/dao/posts.go index 6338fef..634a43d 100644 --- a/app/pkg/dao/posts.go +++ b/app/pkg/dao/posts.go @@ -3,6 +3,7 @@ package dao import ( "context" "database/sql" + "errors" "fmt" "github.com/fthvgb1/wp-go/app/pkg/models" "github.com/fthvgb1/wp-go/app/pkg/models/relation" @@ -15,10 +16,8 @@ import ( "time" ) -func GetPostsByIds(a ...any) (m map[uint64]models.Posts, err error) { - ctx := a[0].(context.Context) +func GetPostsByIds(ctx context.Context, ids []uint64, _ ...any) (m map[uint64]models.Posts, err error) { m = make(map[uint64]models.Posts) - ids := a[1].([]uint64) q := model.Conditions( model.Where(model.SqlBuilder{{"Id", "in", ""}}), model.Join(model.SqlBuilder{ @@ -99,8 +98,7 @@ func GetPostsByIds(a ...any) (m map[uint64]models.Posts, err error) { return } -func SearchPostIds(args ...any) (ids PostIds, err error) { - ctx := args[0].(context.Context) +func SearchPostIds(ctx context.Context, _ string, args ...any) (ids PostIds, err error) { q := args[1].(*model.QueryCondition) page := args[2].(int) pageSize := args[3].(int) @@ -146,15 +144,14 @@ func RecentPosts(a ...any) (r []models.Posts, err error) { return } -func GetPostContext(arg ...any) (r PostContext, err error) { - ctx := arg[0].(context.Context) +func GetPostContext(ctx context.Context, _ uint64, arg ...any) (r PostContext, err error) { t := arg[1].(time.Time) next, err := model.FirstOne[models.Posts](ctx, model.SqlBuilder{ {"post_date", ">", t.Format("2006-01-02 15:04:05")}, {"post_status", "in", ""}, {"post_type", "post"}, }, "ID,post_title,post_password", nil, []any{"publish"}) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { err = nil } if err != nil { @@ -165,7 +162,7 @@ func GetPostContext(arg ...any) (r PostContext, err error) { {"post_status", "in", ""}, {"post_type", "post"}, }, "ID,post_title", model.SqlBuilder{{"post_date", "desc"}}, []any{"publish"}) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { err = nil } if err != nil { @@ -178,8 +175,7 @@ func GetPostContext(arg ...any) (r PostContext, err error) { return } -func MonthPost(args ...any) (r []uint64, err error) { - ctx := args[0].(context.Context) +func MonthPost(ctx context.Context, _ string, args ...any) (r []uint64, err error) { year, month := args[1].(string), args[2].(string) where := model.SqlBuilder{ {"post_type", "post"}, diff --git a/app/pkg/dao/users.go b/app/pkg/dao/users.go index 9642978..63ea683 100644 --- a/app/pkg/dao/users.go +++ b/app/pkg/dao/users.go @@ -7,9 +7,7 @@ import ( "github.com/fthvgb1/wp-go/model" ) -func GetUserById(a ...any) (r models.Users, err error) { - ctx := a[0].(context.Context) - uid := a[1].(uint64) +func GetUserById(ctx context.Context, uid uint64, _ ...any) (r models.Users, err error) { r, err = model.FindOneById[models.Users](ctx, uid) return } @@ -27,9 +25,7 @@ func AllUsername(a ...any) (map[string]struct{}, error) { }, true), nil } -func GetUserByName(a ...any) (r models.Users, err error) { - u := a[1].(string) - ctx := a[0].(context.Context) +func GetUserByName(ctx context.Context, u string, _ ...any) (r models.Users, err error) { r, err = model.FirstOne[models.Users](ctx, model.SqlBuilder{{ "user_login", u, }}, "*", nil) diff --git a/app/plugins/digest.go b/app/plugins/digest.go index b2147cd..7cb1b67 100644 --- a/app/plugins/digest.go +++ b/app/plugins/digest.go @@ -22,17 +22,15 @@ var more = regexp.MustCompile("") var removeWpBlock = regexp.MustCompile("") func InitDigestCache() { - digestCache = cachemanager.MapCacheBy[uint64](digestRaw, config.GetConfig().CacheTime.DigestCacheTime) + digestCache = cachemanager.NewMemoryMapCache(nil, digestRaw, config.GetConfig().CacheTime.DigestCacheTime, "digestPlugin") } func RemoveWpBlock(s string) string { return removeWpBlock.ReplaceAllString(s, "") } -func digestRaw(arg ...any) (string, error) { - ctx := arg[0].(context.Context) +func digestRaw(ctx context.Context, id uint64, arg ...any) (string, error) { s := arg[1].(string) - id := arg[2].(uint64) limit := arg[3].(int) if limit < 0 { return s, nil diff --git a/cache/map.go b/cache/map.go index 684bcf5..00ea95e 100644 --- a/cache/map.go +++ b/cache/map.go @@ -10,48 +10,79 @@ import ( ) type MapCache[K comparable, V any] struct { - data Cache[K, V] + handle Cache[K, V] mux sync.Mutex - cacheFunc func(...any) (V, error) - batchCacheFn func(...any) (map[K]V, error) + cacheFunc MapSingleFn[K, V] + batchCacheFn MapBatchFn[K, V] expireTime time.Duration } -func (m *MapCache[K, V]) SetCacheFunc(fn func(...any) (V, error)) { +type MapSingleFn[K, V any] func(context.Context, K, ...any) (V, error) +type MapBatchFn[K comparable, V any] func(context.Context, []K, ...any) (map[K]V, error) + +func NewMapCache[K comparable, V any](data Cache[K, V], cacheFunc MapSingleFn[K, V], batchCacheFn MapBatchFn[K, V], expireTime time.Duration) *MapCache[K, V] { + r := &MapCache[K, V]{ + handle: data, + mux: sync.Mutex{}, + cacheFunc: cacheFunc, + batchCacheFn: batchCacheFn, + expireTime: expireTime, + } + if cacheFunc == nil && batchCacheFn != nil { + r.setDefaultCacheFn(batchCacheFn) + } else if batchCacheFn == nil && cacheFunc != nil { + r.SetDefaultBatchFunc(cacheFunc) + } + return r +} + +func (m *MapCache[K, V]) SetDefaultBatchFunc(fn MapSingleFn[K, V]) { + m.batchCacheFn = func(ctx context.Context, ids []K, a ...any) (map[K]V, error) { + var err error + rr := make(map[K]V) + for _, id := range ids { + v, er := fn(ctx, id) + if er != nil { + err = errors.Join(er) + continue + } + rr[id] = v + } + return rr, err + } +} + +func (m *MapCache[K, V]) SetCacheFunc(fn MapSingleFn[K, V]) { m.cacheFunc = fn } +func (m *MapCache[K, V]) GetHandle() Cache[K, V] { + return m.handle +} func (m *MapCache[K, V]) Ttl(ctx context.Context, k K) time.Duration { - return m.data.Ttl(ctx, k, m.expireTime) + return m.handle.Ttl(ctx, k, m.expireTime) } func (m *MapCache[K, V]) GetLastSetTime(ctx context.Context, k K) (t time.Time) { - tt := m.data.Ttl(ctx, k, m.expireTime) + tt := m.handle.Ttl(ctx, k, m.expireTime) if tt <= 0 { return } - return time.Now().Add(m.data.Ttl(ctx, k, m.expireTime)).Add(-m.expireTime) + return time.Now().Add(m.handle.Ttl(ctx, k, m.expireTime)).Add(-m.expireTime) } -func (m *MapCache[K, V]) SetCacheBatchFn(fn func(...any) (map[K]V, error)) { +func (m *MapCache[K, V]) SetCacheBatchFn(fn MapBatchFn[K, V]) { m.batchCacheFn = fn if m.cacheFunc == nil { - m.setCacheFn(fn) + m.setDefaultCacheFn(fn) } } -func (m *MapCache[K, V]) setCacheFn(fn func(...any) (map[K]V, error)) { - m.cacheFunc = func(a ...any) (V, error) { +func (m *MapCache[K, V]) setDefaultCacheFn(fn MapBatchFn[K, V]) { + m.cacheFunc = func(ctx context.Context, k K, a ...any) (V, error) { var err error var r map[K]V - var k K - ctx, ok := a[0].(context.Context) - if ok { - k, ok = a[1].(K) - if ok { - r, err = fn(ctx, []K{k}) - } - } + r, err = fn(ctx, []K{k}, a...) if err != nil { var rr V @@ -61,52 +92,61 @@ func (m *MapCache[K, V]) setCacheFn(fn func(...any) (map[K]V, error)) { } } -func NewMapCacheByFn[K comparable, V any](cacheType Cache[K, V], fn func(...any) (V, error), expireTime time.Duration) *MapCache[K, V] { - return &MapCache[K, V]{ +func NewMapCacheByFn[K comparable, V any](cacheType Cache[K, V], fn MapSingleFn[K, V], expireTime time.Duration) *MapCache[K, V] { + r := &MapCache[K, V]{ mux: sync.Mutex{}, cacheFunc: fn, expireTime: expireTime, - data: cacheType, + handle: cacheType, } + r.SetDefaultBatchFunc(fn) + return r } -func NewMapCacheByBatchFn[K comparable, V any](cacheType Cache[K, V], fn func(...any) (map[K]V, error), expireTime time.Duration) *MapCache[K, V] { +func NewMapCacheByBatchFn[K comparable, V any](cacheType Cache[K, V], fn MapBatchFn[K, V], expireTime time.Duration) *MapCache[K, V] { r := &MapCache[K, V]{ mux: sync.Mutex{}, batchCacheFn: fn, expireTime: expireTime, - data: cacheType, + handle: cacheType, } - r.setCacheFn(fn) + r.setDefaultCacheFn(fn) return r } func (m *MapCache[K, V]) Flush(ctx context.Context) { m.mux.Lock() defer m.mux.Unlock() - m.data.Flush(ctx) + m.handle.Flush(ctx) } func (m *MapCache[K, V]) Get(ctx context.Context, k K) (V, bool) { - return m.data.Get(ctx, k) + return m.handle.Get(ctx, k) } func (m *MapCache[K, V]) Set(ctx context.Context, k K, v V) { - m.data.Set(ctx, k, v, m.expireTime) + m.handle.Set(ctx, k, v, m.expireTime) +} + +func (m *MapCache[K, V]) Delete(ctx context.Context, k K) { + m.handle.Delete(ctx, k) +} +func (m *MapCache[K, V]) Ver(ctx context.Context, k K) int { + return m.handle.Ver(ctx, k) } func (m *MapCache[K, V]) GetCache(c context.Context, key K, timeout time.Duration, params ...any) (V, error) { - data, ok := m.data.Get(c, key) + data, ok := m.handle.Get(c, key) var err error - if !ok || m.data.Ttl(c, key, m.expireTime) <= 0 { - ver := m.data.Ver(c, key) + if !ok || m.handle.Ttl(c, key, m.expireTime) <= 0 { + ver := m.handle.Ver(c, key) call := func() { m.mux.Lock() defer m.mux.Unlock() - if m.data.Ver(c, key) > ver { - data, _ = m.data.Get(c, key) + if m.handle.Ver(c, key) > ver { + data, _ = m.handle.Get(c, key) return } - data, err = m.cacheFunc(params...) + data, err = m.cacheFunc(c, key, params...) if err != nil { return } @@ -137,10 +177,10 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time. var res []V ver := 0 needFlush := slice.FilterAndMap(key, func(k K) (r K, ok bool) { - if _, ok := m.data.Get(c, k); !ok { + if _, ok := m.handle.Get(c, k); !ok { return k, true } - ver += m.data.Ver(c, k) + ver += m.handle.Ver(c, k) return }) @@ -151,14 +191,14 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time. defer m.mux.Unlock() vers := slice.Reduce(needFlush, func(t K, r int) int { - return r + m.data.Ver(c, t) + return r + m.handle.Ver(c, t) }, 0) if vers > ver { return } - r, er := m.batchCacheFn(params...) + r, er := m.batchCacheFn(c, key, params...) if err != nil { err = er return @@ -185,7 +225,7 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time. } } res = slice.FilterAndMap(key, func(k K) (V, bool) { - return m.data.Get(c, k) + return m.handle.Get(c, k) }) return res, err } @@ -193,5 +233,5 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time. func (m *MapCache[K, V]) ClearExpired(ctx context.Context) { m.mux.Lock() defer m.mux.Unlock() - m.data.ClearExpired(ctx, m.expireTime) + m.handle.ClearExpired(ctx, m.expireTime) } diff --git a/cache/map_test.go b/cache/map_test.go index 7fdb3e7..66f8f40 100644 --- a/cache/map_test.go +++ b/cache/map_test.go @@ -12,19 +12,17 @@ import ( ) var ca MapCache[string, string] -var fn func(a ...any) (string, error) -var batchFn func(a ...any) (map[string]string, error) +var fn MapSingleFn[string, string] +var batchFn MapBatchFn[string, string] var ct context.Context func init() { - fn = func(a ...any) (string, error) { - aa := a[1].(string) + fn = func(ctx context.Context, aa string, a ...any) (string, error) { return strings.Repeat(aa, 2), nil } ct = context.Background() - batchFn = func(a ...any) (map[string]string, error) { + batchFn = func(ctx context.Context, arr []string, a ...any) (map[string]string, error) { fmt.Println(a) - arr := a[1].([]string) return slice.FilterAndToMap(arr, func(t string) (string, string, bool) { return t, strings.Repeat(t, 2), true }), nil @@ -293,7 +291,7 @@ func TestMapCache_Set(t *testing.T) { func TestMapCache_SetCacheBatchFn(t *testing.T) { type args[K comparable, V any] struct { - fn func(...any) (map[K]V, error) + fn MapBatchFn[K, V] } type testCase[K comparable, V any] struct { name string @@ -315,19 +313,19 @@ func TestMapCache_SetCacheBatchFn(t *testing.T) { } func TestMapCache_SetCacheFunc(t *testing.T) { - type args[V any] struct { - fn func(...any) (V, error) + type args[K comparable, V any] struct { + fn MapSingleFn[K, V] } type testCase[K comparable, V any] struct { name string m MapCache[K, V] - args args[V] + args args[K, V] } tests := []testCase[string, string]{ { name: "t1", m: ca, - args: args[string]{fn: fn}, + args: args[string, string]{fn: fn}, }, } for _, tt := range tests { @@ -370,7 +368,7 @@ func TestMapCache_Ttl(t *testing.T) { func TestMapCache_setCacheFn(t *testing.T) { type args[K comparable, V any] struct { - fn func(...any) (map[K]V, error) + fn MapBatchFn[K, V] } type testCase[K comparable, V any] struct { name string @@ -387,7 +385,7 @@ func TestMapCache_setCacheFn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ca.cacheFunc = nil - tt.m.setCacheFn(tt.args.fn) + tt.m.setDefaultCacheFn(tt.args.fn) fmt.Println(ca.GetCache(ct, "xx", time.Second, ct, "xx")) }) } diff --git a/cache/memorymapcache.go b/cache/memorymapcache.go index 028d7b5..5ab8dd0 100644 --- a/cache/memorymapcache.go +++ b/cache/memorymapcache.go @@ -11,24 +11,14 @@ type MemoryMapCache[K comparable, V any] struct { *safety.Map[K, mapVal[V]] } -func NewMemoryMapCacheByFn[K comparable, V any](fn func(...any) (V, error), expireTime time.Duration) *MapCache[K, V] { +func NewMemoryMapCacheByFn[K comparable, V any](fn MapSingleFn[K, V], expireTime time.Duration) *MapCache[K, V] { return &MapCache[K, V]{ - data: NewMemoryMapCache[K, V](), + handle: NewMemoryMapCache[K, V](), cacheFunc: fn, expireTime: expireTime, mux: sync.Mutex{}, } } -func NewMemoryMapCacheByBatchFn[K comparable, V any](fn func(...any) (map[K]V, error), expireTime time.Duration) *MapCache[K, V] { - r := &MapCache[K, V]{ - data: NewMemoryMapCache[K, V](), - batchCacheFn: fn, - expireTime: expireTime, - mux: sync.Mutex{}, - } - r.setCacheFn(fn) - return r -} func NewMemoryMapCache[K comparable, V any]() *MemoryMapCache[K, V] { return &MemoryMapCache[K, V]{Map: safety.NewMap[K, mapVal[V]]()}