完善
This commit is contained in:
parent
f18bbcd1e6
commit
a6049ff78a
77
cache/map.go
vendored
77
cache/map.go
vendored
|
@ -4,13 +4,13 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github/fthvgb1/wp-go/safety"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MapCache[K comparable, V any] struct {
|
||||
data atomic.Value
|
||||
data safety.Map[K, mapCacheStruct[V]]
|
||||
mutex *sync.Mutex
|
||||
cacheFunc func(...any) (V, error)
|
||||
batchCacheFn func(...any) (map[K]V, error)
|
||||
|
@ -18,9 +18,7 @@ type MapCache[K comparable, V any] struct {
|
|||
}
|
||||
|
||||
func NewMapCache[K comparable, V any](expireTime time.Duration) *MapCache[K, V] {
|
||||
var v atomic.Value
|
||||
v.Store(make(map[K]mapCacheStruct[V]))
|
||||
return &MapCache[K, V]{expireTime: expireTime, data: v}
|
||||
return &MapCache[K, V]{expireTime: expireTime}
|
||||
}
|
||||
|
||||
type mapCacheStruct[T any] struct {
|
||||
|
@ -34,7 +32,7 @@ func (m *MapCache[K, V]) SetCacheFunc(fn func(...any) (V, error)) {
|
|||
}
|
||||
|
||||
func (m *MapCache[K, V]) GetSetTime(k K) (t time.Time) {
|
||||
r, ok := m.data.Load().(map[K]mapCacheStruct[V])[k]
|
||||
r, ok := m.data.Load(k)
|
||||
if ok {
|
||||
t = r.setTime
|
||||
}
|
||||
|
@ -61,23 +59,19 @@ func (m *MapCache[K, V]) setCacheFn(fn func(...any) (map[K]V, error)) {
|
|||
}
|
||||
|
||||
func NewMapCacheByFn[K comparable, V any](fn func(...any) (V, error), expireTime time.Duration) *MapCache[K, V] {
|
||||
var d atomic.Value
|
||||
d.Store(make(map[K]mapCacheStruct[V]))
|
||||
return &MapCache[K, V]{
|
||||
mutex: &sync.Mutex{},
|
||||
cacheFunc: fn,
|
||||
expireTime: expireTime,
|
||||
data: d,
|
||||
data: safety.NewMap[K, mapCacheStruct[V]](),
|
||||
}
|
||||
}
|
||||
func NewMapCacheByBatchFn[K comparable, V any](fn func(...any) (map[K]V, error), expireTime time.Duration) *MapCache[K, V] {
|
||||
var d atomic.Value
|
||||
d.Store(make(map[K]mapCacheStruct[V]))
|
||||
r := &MapCache[K, V]{
|
||||
mutex: &sync.Mutex{},
|
||||
batchCacheFn: fn,
|
||||
expireTime: expireTime,
|
||||
data: d,
|
||||
data: safety.NewMap[K, mapCacheStruct[V]](),
|
||||
}
|
||||
r.setCacheFn(fn)
|
||||
return r
|
||||
|
@ -86,13 +80,16 @@ func NewMapCacheByBatchFn[K comparable, V any](fn func(...any) (map[K]V, error),
|
|||
func (m *MapCache[K, V]) Flush() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
var d atomic.Value
|
||||
d.Store(make(map[K]mapCacheStruct[V]))
|
||||
m.data = d
|
||||
m.data = safety.NewMap[K, mapCacheStruct[V]]()
|
||||
}
|
||||
|
||||
func (m *MapCache[K, V]) Get(k K) V {
|
||||
return m.data.Load().(map[K]mapCacheStruct[V])[k].data
|
||||
r, ok := m.data.Load(k)
|
||||
if ok {
|
||||
return r.data
|
||||
}
|
||||
var rr V
|
||||
return rr
|
||||
}
|
||||
|
||||
func (m *MapCache[K, V]) Set(k K, v V) {
|
||||
|
@ -114,26 +111,24 @@ func (m *MapCache[K, V]) SetByBatchFn(params ...any) error {
|
|||
}
|
||||
|
||||
func (m *MapCache[K, V]) set(k K, v V) {
|
||||
d, ok := m.data.Load().(map[K]mapCacheStruct[V])
|
||||
data, ok := m.data.Load(k)
|
||||
t := time.Now()
|
||||
data := d[k]
|
||||
if !ok {
|
||||
data.data = v
|
||||
data.setTime = t
|
||||
data.incr++
|
||||
m.data.Store(k, data)
|
||||
} else {
|
||||
data = mapCacheStruct[V]{
|
||||
m.data.Store(k, mapCacheStruct[V]{
|
||||
data: v,
|
||||
setTime: t,
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
d[k] = data
|
||||
m.data.Store(d)
|
||||
}
|
||||
|
||||
func (m *MapCache[K, V]) GetCache(c context.Context, key K, timeout time.Duration, params ...any) (V, error) {
|
||||
d := m.data.Load().(map[K]mapCacheStruct[V])
|
||||
data, ok := d[key]
|
||||
data, ok := m.data.Load(key)
|
||||
if !ok {
|
||||
data = mapCacheStruct[V]{}
|
||||
}
|
||||
|
@ -144,12 +139,11 @@ func (m *MapCache[K, V]) GetCache(c context.Context, key K, timeout time.Duratio
|
|||
if !ok || (ok && m.expireTime >= 0 && expired) {
|
||||
t := data.incr
|
||||
call := func() {
|
||||
tmp, o := m.data.Load().(map[K]mapCacheStruct[V])[key]
|
||||
if o && tmp.incr > t {
|
||||
return
|
||||
}
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
if data.incr > t {
|
||||
return
|
||||
}
|
||||
r, er := m.cacheFunc(params...)
|
||||
if err != nil {
|
||||
err = er
|
||||
|
@ -157,9 +151,8 @@ func (m *MapCache[K, V]) GetCache(c context.Context, key K, timeout time.Duratio
|
|||
}
|
||||
data.setTime = time.Now()
|
||||
data.data = r
|
||||
m.data.Store(key, data)
|
||||
data.incr++
|
||||
d[key] = data
|
||||
m.data.Store(d)
|
||||
}
|
||||
if timeout > 0 {
|
||||
ctx, cancel := context.WithTimeout(c, timeout)
|
||||
|
@ -187,9 +180,8 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time.
|
|||
var res []V
|
||||
t := 0
|
||||
now := time.Duration(time.Now().UnixNano())
|
||||
data := m.data.Load().(map[K]mapCacheStruct[V])
|
||||
for _, k := range key {
|
||||
d, ok := data[k]
|
||||
d, ok := m.data.Load(k)
|
||||
if !ok {
|
||||
needFlush = append(needFlush, k)
|
||||
continue
|
||||
|
@ -204,17 +196,17 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time.
|
|||
//todo 这里应该判断下取出的值是否为零值,不过怎么操作呢?
|
||||
if len(needFlush) > 0 {
|
||||
call := func() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
tt := 0
|
||||
for _, dd := range needFlush {
|
||||
if ddd, ok := data[dd]; ok {
|
||||
if ddd, ok := m.data.Load(dd); ok {
|
||||
tt = tt + ddd.incr
|
||||
}
|
||||
}
|
||||
if tt > t {
|
||||
return
|
||||
}
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
r, er := m.batchCacheFn(params...)
|
||||
if err != nil {
|
||||
err = er
|
||||
|
@ -242,8 +234,10 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time.
|
|||
}
|
||||
}
|
||||
for _, k := range key {
|
||||
d := data[k]
|
||||
res = append(res, d.data)
|
||||
d, ok := m.data.Load(k)
|
||||
if ok {
|
||||
res = append(res, d.data)
|
||||
}
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
@ -252,11 +246,10 @@ func (m *MapCache[K, V]) ClearExpired() {
|
|||
now := time.Duration(time.Now().UnixNano())
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
data := m.data.Load().(map[K]mapCacheStruct[V])
|
||||
for k, v := range data {
|
||||
m.data.Range(func(k K, v mapCacheStruct[V]) bool {
|
||||
if now > time.Duration(v.setTime.UnixNano())+m.expireTime {
|
||||
delete(data, k)
|
||||
m.data.Delete(k)
|
||||
}
|
||||
}
|
||||
m.data.Store(data)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
50
cache/slice.go
vendored
50
cache/slice.go
vendored
|
@ -4,11 +4,16 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github/fthvgb1/wp-go/safety"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SliceCache[T any] struct {
|
||||
v safety.Var[slice[T]]
|
||||
}
|
||||
|
||||
type slice[T any] struct {
|
||||
data []T
|
||||
mutex *sync.Mutex
|
||||
setCacheFunc func(...any) ([]T, error)
|
||||
|
@ -18,45 +23,52 @@ type SliceCache[T any] struct {
|
|||
}
|
||||
|
||||
func (c *SliceCache[T]) SetTime() time.Time {
|
||||
return c.setTime
|
||||
|
||||
return c.v.Load().setTime
|
||||
}
|
||||
|
||||
func NewSliceCache[T any](fun func(...any) ([]T, error), duration time.Duration) *SliceCache[T] {
|
||||
return &SliceCache[T]{
|
||||
mutex: &sync.Mutex{},
|
||||
setCacheFunc: fun,
|
||||
expireTime: duration,
|
||||
v: safety.NewVar(slice[T]{
|
||||
mutex: &sync.Mutex{},
|
||||
setCacheFunc: fun,
|
||||
expireTime: duration,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SliceCache[T]) FlushCache() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.data = nil
|
||||
mu := c.v.Load().mutex
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
c.v.Delete()
|
||||
}
|
||||
|
||||
func (c *SliceCache[T]) GetCache(ctx context.Context, timeout time.Duration, params ...any) ([]T, error) {
|
||||
l := len(c.data)
|
||||
data := c.data
|
||||
v := c.v.Load()
|
||||
l := len(v.data)
|
||||
data := v.data
|
||||
var err error
|
||||
expired := time.Duration(c.setTime.UnixNano())+c.expireTime < time.Duration(time.Now().UnixNano())
|
||||
if l < 1 || (l > 0 && c.expireTime >= 0 && expired) {
|
||||
t := c.incr
|
||||
expired := time.Duration(v.setTime.UnixNano())+v.expireTime < time.Duration(time.Now().UnixNano())
|
||||
if l < 1 || (l > 0 && v.expireTime >= 0 && expired) {
|
||||
t := v.incr
|
||||
call := func() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.incr > t {
|
||||
v := c.v.Load()
|
||||
v.mutex.Lock()
|
||||
defer v.mutex.Unlock()
|
||||
if v.incr > t {
|
||||
return
|
||||
}
|
||||
r, er := c.setCacheFunc(params...)
|
||||
r, er := v.setCacheFunc(params...)
|
||||
if err != nil {
|
||||
err = er
|
||||
return
|
||||
}
|
||||
c.setTime = time.Now()
|
||||
c.data = r
|
||||
v.setTime = time.Now()
|
||||
v.data = r
|
||||
data = r
|
||||
c.incr++
|
||||
v.incr++
|
||||
c.v.Store(v)
|
||||
}
|
||||
if timeout > 0 {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
|
|
|
@ -2,6 +2,7 @@ package middleware
|
|||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github/fthvgb1/wp-go/helper"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
@ -23,18 +24,20 @@ func FlowLimit(maxRequestSleepNum, maxRequestNum int64, sleepTime []time.Duratio
|
|||
return
|
||||
}
|
||||
|
||||
atomic.AddInt64(&flow, 1)
|
||||
defer func() {
|
||||
atomic.AddInt64(&flow, -1)
|
||||
}()
|
||||
if flow >= maxRequestSleepNum && flow <= maxRequestNum {
|
||||
//t := helper.RandNum(sleepTime[0], sleepTime[1])
|
||||
//time.Sleep(t)
|
||||
} else if flow > maxRequestNum {
|
||||
n := atomic.LoadInt64(&flow)
|
||||
if n >= maxRequestSleepNum && n <= maxRequestNum {
|
||||
t := helper.RandNum(sleepTime[0], sleepTime[1])
|
||||
time.Sleep(t)
|
||||
} else if n > maxRequestNum {
|
||||
c.String(http.StatusForbidden, "请求太多了,服务器君表示压力山大==!, 请稍后访问")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
atomic.AddInt64(&flow, 1)
|
||||
defer func() {
|
||||
atomic.AddInt64(&flow, -1)
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,14 +8,14 @@ import (
|
|||
)
|
||||
|
||||
type IpLimitMap struct {
|
||||
mux *sync.Mutex
|
||||
mux *sync.RWMutex
|
||||
m map[string]*int64
|
||||
limitNum int64
|
||||
}
|
||||
|
||||
func IpLimit(num int64) func(ctx *gin.Context) {
|
||||
m := IpLimitMap{
|
||||
mux: &sync.Mutex{},
|
||||
mux: &sync.RWMutex{},
|
||||
m: make(map[string]*int64),
|
||||
limitNum: num,
|
||||
}
|
||||
|
@ -23,31 +23,35 @@ func IpLimit(num int64) func(ctx *gin.Context) {
|
|||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
s := false
|
||||
m.mux.RLock()
|
||||
i, ok := m.m[ip]
|
||||
m.mux.RUnlock()
|
||||
defer func() {
|
||||
i, ok := m.m[ip]
|
||||
if ok && s && *i > 0 {
|
||||
//time.Sleep(time.Second * 3)
|
||||
ii := atomic.LoadInt64(i)
|
||||
if s && ii > 0 {
|
||||
atomic.AddInt64(i, -1)
|
||||
if *i == 0 {
|
||||
if atomic.LoadInt64(i) == 0 {
|
||||
m.mux.Lock()
|
||||
delete(m.m, ip)
|
||||
m.mux.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
i, ok := m.m[ip]
|
||||
|
||||
if !ok {
|
||||
m.mux.Lock()
|
||||
i = new(int64)
|
||||
m.m[ip] = i
|
||||
m.mux.Unlock()
|
||||
}
|
||||
if m.limitNum > 0 && *i >= m.limitNum {
|
||||
|
||||
if m.limitNum > 0 && atomic.LoadInt64(i) >= m.limitNum {
|
||||
c.Status(http.StatusForbidden)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
s = true
|
||||
atomic.AddInt64(i, 1)
|
||||
s = true
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
|
15
middleware/searchlimit.go
Normal file
15
middleware/searchlimit.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package middleware
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func SearchLimit(num int64) func(ctx *gin.Context) {
|
||||
fn := IpLimit(num)
|
||||
return func(c *gin.Context) {
|
||||
if "/" == c.FullPath() && c.Query("s") != "" {
|
||||
fn(c)
|
||||
} else {
|
||||
c.Next()
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -39,8 +39,8 @@ func SetupRouter() *gin.Engine {
|
|||
r.Use(
|
||||
middleware.ValidateServerNames(),
|
||||
gin.Logger(),
|
||||
//middleware.FlowLimit(vars.Conf.MaxRequestSleepNum, vars.Conf.MaxRequestNum, vars.Conf.SingleIpSearchNum, vars.Conf.SleepTime),
|
||||
gin.Recovery(),
|
||||
middleware.FlowLimit(vars.Conf.MaxRequestSleepNum, vars.Conf.MaxRequestNum, vars.Conf.SleepTime),
|
||||
middleware.SetStaticFileCache,
|
||||
)
|
||||
//gzip 因为一般会用nginx做反代时自动使用gzip,所以go这边本身可以不用
|
||||
|
@ -58,7 +58,7 @@ func SetupRouter() *gin.Engine {
|
|||
}))
|
||||
store := cookie.NewStore([]byte("secret"))
|
||||
r.Use(sessions.Sessions("go-wp", store))
|
||||
r.GET("/", actions.Index)
|
||||
r.GET("/", middleware.SearchLimit(vars.Conf.SingleIpSearchNum), actions.Index)
|
||||
r.GET("/page/:page", actions.Index)
|
||||
r.GET("/p/category/:category", actions.Index)
|
||||
r.GET("/p/category/:category/page/:page", actions.Index)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package safeMap
|
||||
package safety
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
@ -59,9 +59,8 @@ type Map[K comparable, V any] struct {
|
|||
}
|
||||
|
||||
func NewMap[K comparable, V any]() Map[K, V] {
|
||||
var r V
|
||||
return Map[K, V]{
|
||||
expunged: unsafe.Pointer(&r),
|
||||
expunged: unsafe.Pointer(new(any)),
|
||||
}
|
||||
}
|
||||
|
37
safety/vars.go
Normal file
37
safety/vars.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package safety
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type Var[T any] struct {
|
||||
val T
|
||||
p unsafe.Pointer
|
||||
}
|
||||
|
||||
func NewVar[T any](val T) Var[T] {
|
||||
return Var[T]{val: val, p: unsafe.Pointer(&val)}
|
||||
}
|
||||
|
||||
func (r *Var[T]) Load() T {
|
||||
return *(*T)(atomic.LoadPointer(&r.p))
|
||||
}
|
||||
|
||||
func (r *Var[T]) Delete() {
|
||||
for {
|
||||
px := atomic.LoadPointer(&r.p)
|
||||
if atomic.CompareAndSwapPointer(&r.p, px, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Var[T]) Store(v T) {
|
||||
for {
|
||||
px := atomic.LoadPointer(&r.p)
|
||||
if atomic.CompareAndSwapPointer(&r.p, px, unsafe.Pointer(&v)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
41
safety/vars_test.go
Normal file
41
safety/vars_test.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package safety
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func TestVar_Load(t *testing.T) {
|
||||
type fields struct {
|
||||
val string
|
||||
p unsafe.Pointer
|
||||
}
|
||||
s := ""
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "t1",
|
||||
fields: fields{
|
||||
val: s,
|
||||
p: unsafe.Pointer(&s),
|
||||
},
|
||||
want: "sffs",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Var[string]{
|
||||
val: tt.fields.val,
|
||||
p: tt.fields.p,
|
||||
}
|
||||
r.Store(tt.want)
|
||||
if got := r.Load(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Load() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user