wp-go/cache/pagination.go
2023-12-10 19:15:49 +08:00

168 lines
4.7 KiB
Go

package cache
import (
"context"
"errors"
"fmt"
"github.com/fthvgb1/wp-go/helper"
"github.com/fthvgb1/wp-go/helper/number"
str "github.com/fthvgb1/wp-go/helper/strings"
"github.com/fthvgb1/wp-go/safety"
"strings"
"time"
)
type Pagination[K comparable, V any] struct {
*MapCache[string, helper.PaginationData[V]]
maxNum func() int
isSwitch *safety.Map[K, bool]
dbFn func(ctx context.Context, k K, page, limit, totalRaw int, a ...any) ([]V, int, error)
localFn func(ctx context.Context, data []V, k K, page, limit int, a ...any) ([]V, int, error)
batchFetchNum func() int
localKeyFn func(K K, a ...any) string
dbKeyFn func(K K, a ...any) string
name string
}
var switchDb = errors.New("switch Db")
type DbFn[K comparable, V any] func(ctx context.Context, k K, page, limit, totalRaw int, a ...any) ([]V, int, error)
type LocalFn[K comparable, V any] func(ctx context.Context, data []V, k K, page, limit int, a ...any) ([]V, int, error)
func NewPagination[K comparable, V any](m *MapCache[string, helper.PaginationData[V]], maxNum func() int, dbFn DbFn[K, V], localFn LocalFn[K, V], dbKeyFn, localKeyFn func(K, ...any) string, batchFetchNum func() int, name string) *Pagination[K, V] {
if dbKeyFn == nil {
dbKeyFn = func(k K, a ...any) string {
s := str.NewBuilder()
for _, v := range append([]any{k}, a...) {
s.Sprintf("%v|", v)
}
return strings.TrimRight(s.String(), "|")
}
}
if localKeyFn == nil {
localKeyFn = func(k K, a ...any) string {
return fmt.Sprintf("%v", k)
}
}
return &Pagination[K, V]{
MapCache: m,
maxNum: maxNum,
isSwitch: safety.NewMap[K, bool](),
dbFn: dbFn,
localFn: localFn,
batchFetchNum: batchFetchNum,
name: name,
dbKeyFn: dbKeyFn,
localKeyFn: localKeyFn,
}
}
func (p *Pagination[K, V]) Pagination(ctx context.Context, timeout time.Duration, k K, page, limit int, a ...any) ([]V, int, error) {
if is, _ := p.isSwitch.Load(k); is {
return p.paginationByDB(ctx, timeout, k, page, limit, 0, a...)
}
data, total, err := p.paginationByLocal(ctx, timeout, k, page, limit, a...)
if err != nil {
if errors.Is(err, switchDb) {
err = nil
return p.paginationByDB(ctx, timeout, k, page, limit, total, a...)
}
return nil, 0, err
}
return data, total, err
}
func (p *Pagination[K, V]) paginationByLocal(ctx context.Context, timeout time.Duration, k K, page, limit int, a ...any) ([]V, int, error) {
key := p.localKeyFn(k)
data, ok := p.Get(ctx, key)
if ok {
if p.increaseUpdate != nil && p.refresh != nil {
dat, err := p.increaseUpdates(ctx, timeout, data, key, a...)
if err != nil {
return nil, 0, err
}
if dat.TotalRaw >= p.maxNum() {
return nil, 0, switchDb
}
data = dat
}
return p.localFn(ctx, data.Data, k, page, limit, a...)
}
p.mux.Lock()
defer p.mux.Unlock()
data, ok = p.Get(ctx, key)
if ok {
return data.Data, data.TotalRaw, nil
}
batchNum := p.batchFetchNum()
da, totalRaw, err := p.fetchDb(ctx, timeout, k, 1, 0, 0, a...)
if err != nil {
return nil, 0, err
}
if totalRaw >= p.maxNum() {
p.isSwitch.Store(k, true)
return nil, totalRaw, switchDb
}
if len(da) < totalRaw {
totalPage := number.DivideCeil(totalRaw, batchNum)
for i := 1; i <= totalPage; i++ {
daa, _, err := p.fetchDb(ctx, timeout, k, i, batchNum, totalRaw, a...)
if err != nil {
return nil, 0, err
}
da = append(da, daa...)
}
data.Data = da
data.TotalRaw = totalRaw
p.Set(ctx, key, data)
}
return p.localFn(ctx, data.Data, k, page, limit, a...)
}
func (p *Pagination[K, V]) paginationByDB(ctx context.Context, timeout time.Duration, k K, page, limit, totalRaw int, a ...any) ([]V, int, error) {
key := p.dbKeyFn(k, append([]any{page, limit}, a...)...)
data, ok := p.Get(ctx, key)
if ok {
return data.Data, data.TotalRaw, nil
}
p.mux.Lock()
defer p.mux.Unlock()
data, ok = p.Get(ctx, key)
if ok {
return data.Data, data.TotalRaw, nil
}
dat, total, err := p.fetchDb(ctx, timeout, k, page, limit, totalRaw, a...)
if err != nil {
return nil, 0, err
}
data.Data, data.TotalRaw = dat, total
p.Set(ctx, key, data)
return data.Data, data.TotalRaw, err
}
func (p *Pagination[K, V]) fetchDb(ctx context.Context, timeout time.Duration, k K, page, limit, totalRaw int, a ...any) ([]V, int, error) {
var data helper.PaginationData[V]
var err error
fn := func() {
da, total, er := p.dbFn(ctx, k, page, limit, totalRaw, a...)
if er != nil {
err = er
return
}
data.Data = da
data.TotalRaw = total
}
if timeout > 0 {
er := helper.RunFnWithTimeout(ctx, timeout, fn, fmt.Sprintf("fetch %s-[%v]-page[%d]-limit[%d] from db fail", p.name, k, page, limit))
if err == nil && er != nil {
err = er
}
} else {
fn()
}
return data.Data, data.TotalRaw, err
}