diff --git a/cache/map.go b/cache/map.go index 2e43191..405fbaf 100644 --- a/cache/map.go +++ b/cache/map.go @@ -5,11 +5,12 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "time" ) type MapCache[K comparable, V any] struct { - data map[K]mapCacheStruct[V] + data atomic.Value mutex *sync.Mutex cacheFunc func(...any) (V, error) batchCacheFn func(...any) (map[K]V, error) @@ -17,7 +18,9 @@ type MapCache[K comparable, V any] struct { } func NewMapCache[K comparable, V any](expireTime time.Duration) *MapCache[K, V] { - return &MapCache[K, V]{expireTime: expireTime} + var v atomic.Value + v.Store(make(map[K]mapCacheStruct[V])) + return &MapCache[K, V]{expireTime: expireTime, data: v} } type mapCacheStruct[T any] struct { @@ -31,7 +34,7 @@ func (m *MapCache[K, V]) SetCacheFunc(fn func(...any) (V, error)) { } func (m *MapCache[K, V]) GetSetTime(k K) (t time.Time) { - r, ok := m.data[k] + r, ok := m.data.Load().(map[K]mapCacheStruct[V])[k] if ok { t = r.setTime } @@ -58,19 +61,23 @@ func (m *MapCache[K, V]) setCacheFn(fn func(...any) (map[K]V, error)) { } func NewMapCacheByFn[K comparable, V any](fn func(...any) (V, error), expireTime time.Duration) *MapCache[K, V] { + var d atomic.Value + d.Store(make(map[K]mapCacheStruct[V])) return &MapCache[K, V]{ mutex: &sync.Mutex{}, cacheFunc: fn, expireTime: expireTime, - data: make(map[K]mapCacheStruct[V]), + data: d, } } func NewMapCacheByBatchFn[K comparable, V any](fn func(...any) (map[K]V, error), expireTime time.Duration) *MapCache[K, V] { + var d atomic.Value + d.Store(make(map[K]mapCacheStruct[V])) r := &MapCache[K, V]{ mutex: &sync.Mutex{}, batchCacheFn: fn, expireTime: expireTime, - data: make(map[K]mapCacheStruct[V]), + data: d, } r.setCacheFn(fn) return r @@ -79,11 +86,13 @@ func NewMapCacheByBatchFn[K comparable, V any](fn func(...any) (map[K]V, error), func (m *MapCache[K, V]) Flush() { m.mutex.Lock() defer m.mutex.Unlock() - m.data = make(map[K]mapCacheStruct[V]) + var d atomic.Value + d.Store(make(map[K]mapCacheStruct[V])) + m.data = d } func (m *MapCache[K, V]) Get(k K) V { - return m.data[k].data + return m.data.Load().(map[K]mapCacheStruct[V])[k].data } func (m *MapCache[K, V]) Set(k K, v V) { @@ -107,23 +116,26 @@ func (m *MapCache[K, V]) SetByBatchFn(params ...any) error { } func (m *MapCache[K, V]) set(k K, v V) { - data, ok := m.data[k] + d, ok := m.data.Load().(map[K]mapCacheStruct[V]) t := time.Now() + data := d[k] if !ok { data.data = v data.setTime = t data.incr++ - m.data[k] = data } else { - m.data[k] = mapCacheStruct[V]{ + data = mapCacheStruct[V]{ data: v, setTime: t, } } + d[k] = data + m.data.Store(d) } func (m *MapCache[K, V]) GetCache(c context.Context, key K, timeout time.Duration, params ...any) (V, error) { - data, ok := m.data[key] + d := m.data.Load().(map[K]mapCacheStruct[V]) + data, ok := d[key] if !ok { data = mapCacheStruct[V]{} } @@ -134,11 +146,12 @@ func (m *MapCache[K, V]) GetCache(c context.Context, key K, timeout time.Duratio if !ok || (ok && m.expireTime >= 0 && expired) { t := data.incr call := func() { - m.mutex.Lock() - defer m.mutex.Unlock() - if data.incr > t { + tmp, o := m.data.Load().(map[K]mapCacheStruct[V])[key] + if o && tmp.incr > t { return } + m.mutex.Lock() + defer m.mutex.Unlock() r, er := m.cacheFunc(params...) if err != nil { err = er @@ -146,8 +159,9 @@ func (m *MapCache[K, V]) GetCache(c context.Context, key K, timeout time.Duratio } data.setTime = time.Now() data.data = r - m.data[key] = data data.incr++ + d[key] = data + m.data.Store(d) } if timeout > 0 { ctx, cancel := context.WithTimeout(c, timeout) @@ -175,8 +189,9 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time. var res []V t := 0 now := time.Duration(time.Now().UnixNano()) + data := m.data.Load().(map[K]mapCacheStruct[V]) for _, k := range key { - d, ok := m.data[k] + d, ok := data[k] if !ok { needFlush = append(needFlush, k) continue @@ -195,7 +210,7 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time. defer m.mutex.Unlock() tt := 0 for _, dd := range needFlush { - if ddd, ok := m.data[dd]; ok { + if ddd, ok := data[dd]; ok { tt = tt + ddd.incr } } @@ -229,7 +244,7 @@ func (m *MapCache[K, V]) GetCacheBatch(c context.Context, key []K, timeout time. } } for _, k := range key { - d := m.data[k] + d := data[k] res = append(res, d.data) } return res, err @@ -239,9 +254,11 @@ func (m *MapCache[K, V]) ClearExpired() { now := time.Duration(time.Now().UnixNano()) m.mutex.Lock() defer m.mutex.Unlock() - for k, v := range m.data { + data := m.data.Load().(map[K]mapCacheStruct[V]) + for k, v := range data { if now > time.Duration(v.setTime.UnixNano())+m.expireTime { - delete(m.data, k) + delete(data, k) } } + m.data.Store(data) } diff --git a/helper/func.go b/helper/func.go index 0cd5d28..8cf13b3 100644 --- a/helper/func.go +++ b/helper/func.go @@ -5,11 +5,17 @@ import ( "fmt" "github.com/dlclark/regexp2" "io" + "math/rand" "reflect" "regexp" "strings" ) +type IntNumber interface { + ~int | ~int64 | ~int32 | ~int8 | ~int16 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 +} + func ToAny[T any](v T) any { return v } @@ -33,7 +39,7 @@ func StructColumn[T any, M any](arr []M, field string) (r []T) { return } -func RangeSlice[T ~int | ~uint | ~int64 | ~int8 | ~int16 | ~int32 | ~uint64](start, end, step T) []T { +func RangeSlice[T IntNumber](start, end, step T) []T { if step == 0 { panic("step can't be 0") } @@ -250,3 +256,8 @@ func SliceToMap[K comparable, V, T any](arr []V, fn func(V) (K, T), isCoverPrev } return m } + +func RandNum[T IntNumber](start, end T) T { + end++ + return T(rand.Int63n(int64(end-start))) + start +} diff --git a/helper/func_test.go b/helper/func_test.go index 11eba36..2267caa 100644 --- a/helper/func_test.go +++ b/helper/func_test.go @@ -578,3 +578,32 @@ func TestSimpleSliceToMap(t *testing.T) { }) } } + +func TestRandNum(t *testing.T) { + type args struct { + start int + end int + } + tests := []struct { + name string + args args + }{ + { + name: "t1", + args: args{ + start: 1, + end: 2, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for i := 0; i < 100; i++ { + got := RandNum(tt.args.start, tt.args.end) + if got > tt.args.end || got < tt.args.start { + t.Errorf("RandNum() = %v, range error", got) + } + } + }) + } +} diff --git a/main.go b/main.go index 6d9b4d3..1ea84fe 100644 --- a/main.go +++ b/main.go @@ -8,10 +8,12 @@ import ( "github/fthvgb1/wp-go/plugins" "github/fthvgb1/wp-go/route" "github/fthvgb1/wp-go/vars" + "math/rand" "time" ) func init() { + rand.Seed(time.Now().UnixNano()) err := vars.InitConfig() if err != nil { panic(err) diff --git a/middleware/flowLimit.go b/middleware/flowLimit.go index 4e051d7..3ccefbc 100644 --- a/middleware/flowLimit.go +++ b/middleware/flowLimit.go @@ -2,32 +2,14 @@ package middleware import ( "github.com/gin-gonic/gin" - "math/rand" "net/http" "strings" - "sync" "sync/atomic" "time" ) -type IpLimitMap struct { - mux *sync.Mutex - m map[string]*int64 - singleIpSearchNum int64 -} - -func FlowLimit(maxRequestSleepNum, maxRequestNum, singleIpSearchNum int64, sleepTime []time.Duration) func(ctx *gin.Context) { +func FlowLimit(maxRequestSleepNum, maxRequestNum int64, sleepTime []time.Duration) func(ctx *gin.Context) { var flow int64 - rand.Seed(time.Now().UnixNano()) - randFn := func(start, end time.Duration) time.Duration { - end++ - return time.Duration(rand.Intn(int(end-start)) + int(start)) - } - m := IpLimitMap{ - mux: &sync.Mutex{}, - m: make(map[string]*int64), - singleIpSearchNum: singleIpSearchNum, - } statPath := map[string]struct{}{ "wp-includes": {}, "wp-content": {}, @@ -40,58 +22,19 @@ func FlowLimit(maxRequestSleepNum, maxRequestNum, singleIpSearchNum int64, sleep c.Next() return } - s := false - ip := c.ClientIP() - defer m.searchLimit(false, c, ip, f, &s) - if m.searchLimit(true, c, ip, f, &s) { - c.Abort() - return - } + atomic.AddInt64(&flow, 1) defer func() { atomic.AddInt64(&flow, -1) }() if flow >= maxRequestSleepNum && flow <= maxRequestNum { - t := randFn(sleepTime[0], sleepTime[1]) - time.Sleep(t) + //t := helper.RandNum(sleepTime[0], sleepTime[1]) + //time.Sleep(t) } else if flow > maxRequestNum { c.String(http.StatusForbidden, "请求太多了,服务器君表示压力山大==!, 请稍后访问") c.Abort() - return } c.Next() - } } - -func (m *IpLimitMap) searchLimit(start bool, c *gin.Context, ip string, f []string, s *bool) (isForbid bool) { - if f[0] == "" && c.Query("s") != "" { - if start { - i, ok := m.m[ip] - if !ok { - m.mux.Lock() - i = new(int64) - m.m[ip] = i - m.mux.Unlock() - } - if m.singleIpSearchNum > 0 && *i >= m.singleIpSearchNum { - isForbid = true - return - } - *s = true - atomic.AddInt64(i, 1) - return - } - i, ok := m.m[ip] - if ok && *s && *i > 0 { - atomic.AddInt64(i, -1) - if *i == 0 { - m.mux.Lock() - delete(m.m, ip) - m.mux.Unlock() - } - } - } - return -} diff --git a/middleware/iplimit.go b/middleware/iplimit.go new file mode 100644 index 0000000..ca416b2 --- /dev/null +++ b/middleware/iplimit.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "net/http" + "sync" + "sync/atomic" +) + +type IpLimitMap struct { + mux *sync.Mutex + m map[string]*int64 + limitNum int64 +} + +func IpLimit(num int64) func(ctx *gin.Context) { + m := IpLimitMap{ + mux: &sync.Mutex{}, + m: make(map[string]*int64), + limitNum: num, + } + + return func(c *gin.Context) { + ip := c.ClientIP() + s := false + defer func() { + i, ok := m.m[ip] + if ok && s && *i > 0 { + //time.Sleep(time.Second * 3) + atomic.AddInt64(i, -1) + if *i == 0 { + m.mux.Lock() + delete(m.m, ip) + m.mux.Unlock() + } + } + }() + i, ok := m.m[ip] + if !ok { + m.mux.Lock() + i = new(int64) + m.m[ip] = i + m.mux.Unlock() + } + if m.limitNum > 0 && *i >= m.limitNum { + c.Status(http.StatusForbidden) + c.Abort() + return + } + s = true + atomic.AddInt64(i, 1) + } +} diff --git a/models/model.go b/models/model.go index 1b50190..6e47cd6 100644 --- a/models/model.go +++ b/models/model.go @@ -229,7 +229,7 @@ func SimplePagination[T Model](where ParseWhere, fields, group string, page, pag return } -func FindOneById[T Model, I ~int | ~uint64 | ~int64 | ~int32](id I) (T, error) { +func FindOneById[T Model, I helper.IntNumber](id I) (T, error) { var r T sql := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) err := db.Db.Get(&r, sql, id) diff --git a/route/route.go b/route/route.go index 09cf0ef..7107051 100644 --- a/route/route.go +++ b/route/route.go @@ -39,7 +39,8 @@ func SetupRouter() *gin.Engine { r.Use( middleware.ValidateServerNames(), gin.Logger(), - middleware.FlowLimit(vars.Conf.MaxRequestSleepNum, vars.Conf.MaxRequestNum, vars.Conf.SingleIpSearchNum, vars.Conf.SleepTime), + middleware.FlowLimit(vars.Conf.MaxRequestSleepNum, vars.Conf.MaxRequestNum, vars.Conf.SleepTime), + middleware.IpLimit(10), gin.Recovery(), middleware.SetStaticFileCache, )