diff --git a/cache/map.go b/cache/map.go index 2a2e631..158de1b 100644 --- a/cache/map.go +++ b/cache/map.go @@ -131,56 +131,66 @@ func (m *MapCache[K, V]) GetCache(c context.Context, key K, timeout time.Duratio func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time.Duration, params ...any) ([]V, error) { var res []V ver := 0 - needFlush := slice.FilterAndMap(key, func(k K) (r K, ok bool) { - if _, ok := m.Get(c, k); !ok { - ver += m.Ver(c, k) - return k, true - } - return - }) - - var err error - if len(needFlush) > 0 { - call := func() { - m.mux.Lock() - defer m.mux.Unlock() - - vers := slice.Reduce(needFlush, func(t K, r int) int { - return r + m.Ver(c, t) - }, 0) - - if vers > ver { - return - } - - r, er := m.batchCacheFn(c, key, params...) - if err != nil { - err = er - return - } - for k, v := range r { - m.Set(c, k, v) - } - } - if timeout > 0 { - ctx, cancel := context.WithTimeout(c, timeout) - defer cancel() - done := make(chan struct{}, 1) - go func() { - call() - done <- struct{}{} - }() - select { - case <-ctx.Done(): - err = errors.New(fmt.Sprintf("get cache %v %s", key, ctx.Err().Error())) - case <-done: - } + var needFlush []K + var needIndex = make(map[int]K) + slice.ForEach(key, func(i int, k K) { + v, ok := m.Get(c, k) + var vv V + if ok { + vv = v } else { - call() + needFlush = append(needFlush, k) + ver += m.Ver(c, k) + needIndex[i] = k + } + res = append(res, vv) + }) + if len(needFlush) < 1 { + return res, nil + } + var err error + call := func() { + m.mux.Lock() + defer m.mux.Unlock() + + vers := slice.Reduce(needFlush, func(t K, r int) int { + return r + m.Ver(c, t) + }, 0) + + if vers > ver { + return + } + + r, er := m.batchCacheFn(c, key, params...) + if err != nil { + err = er + return + } + for k, v := range r { + m.Set(c, k, v) + } + } + if timeout > 0 { + ctx, cancel := context.WithTimeout(c, timeout) + defer cancel() + done := make(chan struct{}, 1) + go func() { + call() + done <- struct{}{} + }() + select { + case <-ctx.Done(): + err = errors.New(fmt.Sprintf("get cache %v %s", key, ctx.Err().Error())) + case <-done: + } + } else { + call() + } + for index, k := range needIndex { + v, ok := m.Get(c, k) + if ok { + res[index] = v } } - res = slice.FilterAndMap(key, func(k K) (V, bool) { - return m.Get(c, k) - }) return res, err }