diff --git a/app/middleware/iplimit.go b/app/middleware/iplimit.go index 103c54c..d233095 100755 --- a/app/middleware/iplimit.go +++ b/app/middleware/iplimit.go @@ -1,7 +1,6 @@ package middleware import ( - "github.com/fthvgb1/wp-go/helper/number" str "github.com/fthvgb1/wp-go/helper/strings" "github.com/gin-gonic/gin" "net/http" @@ -21,14 +20,14 @@ type LimitMap[K comparable] struct { type FlowLimits[K comparable] struct { GetKeyFn func(ctx *gin.Context) K LimitedFns func(ctx *gin.Context) - DeferClearFn func(c *gin.Context, m LimitMap[K], k K, v *int64) - AddFn func(c *gin.Context, m LimitMap[K], k K, v *int64) + DeferClearFn func(c *gin.Context, m LimitMap[K], k K, v, currentTotalFlow *int64) + AddFn func(c *gin.Context, m LimitMap[K], k K, v, currentTotalFlow *int64) } func NewFlowLimits[K comparable](getKeyFn func(ctx *gin.Context) K, limitedFns func(ctx *gin.Context), - deferClearFn func(c *gin.Context, m LimitMap[K], k K, v *int64), - addFns ...func(c *gin.Context, m LimitMap[K], k K, v *int64), + deferClearFn func(c *gin.Context, m LimitMap[K], k K, v, currentTotalFlow *int64), + addFns ...func(c *gin.Context, m LimitMap[K], k K, v, currentTotalFlow *int64), ) *FlowLimits[K] { f := &FlowLimits[K]{ @@ -51,23 +50,24 @@ func (f FlowLimits[K]) GetKey(c *gin.Context) K { func (f FlowLimits[K]) Limit(c *gin.Context) { f.LimitedFns(c) } -func (f FlowLimits[K]) DeferClear(c *gin.Context, m LimitMap[K], k K, v *int64) { - f.DeferClearFn(c, m, k, v) +func (f FlowLimits[K]) DeferClear(c *gin.Context, m LimitMap[K], k K, v, currentTotalFlow *int64) { + f.DeferClearFn(c, m, k, v, currentTotalFlow) } -func (f FlowLimits[K]) Add(c *gin.Context, m LimitMap[K], k K, v *int64) { - f.AddFn(c, m, k, v) +func (f FlowLimits[K]) Add(c *gin.Context, m LimitMap[K], k K, v, currentTotalFlow *int64) { + f.AddFn(c, m, k, v, currentTotalFlow) } -func (f FlowLimits[K]) Adds(_ *gin.Context, _ LimitMap[K], _ K, v *int64) { +func (f FlowLimits[K]) Adds(_ *gin.Context, _ LimitMap[K], _ K, v, currentTotalFlow *int64) { atomic.AddInt64(v, 1) + atomic.AddInt64(currentTotalFlow, 1) } type MapFlowLimit[K comparable] interface { GetKey(c *gin.Context) K Limit(c *gin.Context) - DeferClear(c *gin.Context, m LimitMap[K], k K, v *int64) - Add(c *gin.Context, m LimitMap[K], k K, v *int64) + DeferClear(c *gin.Context, m LimitMap[K], k K, v, currentTotalFlow *int64) + Add(c *gin.Context, m LimitMap[K], k K, v, currentTotalFlow *int64) } func CustomFlowLimit[K comparable](a MapFlowLimit[K], maxRequestNum int64, clearNum ...int64) (func(ctx *gin.Context), func(int64, ...int64)) { @@ -88,6 +88,7 @@ func CustomFlowLimit[K comparable](a MapFlowLimit[K], maxRequestNum int64, clear atomic.StoreInt64(m.ClearNum, clearNum[0]) } } + currentTotalFlow := new(int64) return func(c *gin.Context) { if atomic.LoadInt64(m.LimitNum) <= 0 { c.Next() @@ -103,8 +104,9 @@ func CustomFlowLimit[K comparable](a MapFlowLimit[K], maxRequestNum int64, clear m.Map[key] = i m.Mux.Unlock() } - a.Add(c, m, key, i) - defer a.DeferClear(c, m, key, i) + a.Add(c, m, key, i, currentTotalFlow) + defer a.DeferClear(c, m, key, i, currentTotalFlow) + defer atomic.AddInt64(currentTotalFlow, -1) if atomic.LoadInt64(i) > atomic.LoadInt64(m.LimitNum) { a.Limit(c) return @@ -113,7 +115,7 @@ func CustomFlowLimit[K comparable](a MapFlowLimit[K], maxRequestNum int64, clear }, fn } -func IpLimitClear[K comparable](_ *gin.Context, m LimitMap[K], key K, i *int64) { +func IpLimitClear[K comparable](_ *gin.Context, m LimitMap[K], key K, i, currentTotalFlow *int64) { atomic.AddInt64(i, -1) if atomic.LoadInt64(i) <= 0 { cNum := int(atomic.LoadInt64(m.ClearNum)) @@ -124,10 +126,7 @@ func IpLimitClear[K comparable](_ *gin.Context, m LimitMap[K], key K, i *int64) return } - m.Mux.RLock() - l := len(m.Map) - m.Mux.RUnlock() - if l < cNum { + if int(atomic.LoadInt64(currentTotalFlow)) <= cNum { m.Mux.Lock() for k, v := range m.Map { if atomic.LoadInt64(v) < 1 { @@ -158,35 +157,25 @@ func IpLimit(num int64, clearNum ...int64) (func(ctx *gin.Context), func(int64, } func IpMinuteLimit(num int64, clearNum ...int64) (func(ctx *gin.Context), func(int64, ...int64)) { - total := new(int64) a := NewFlowLimits(func(c *gin.Context) string { return str.Join(c.ClientIP(), "|", time.Now().Format("2006-01-02 15:04")) - }, - ToManyRequest(), - IpMinuteLimitDeferFn(total), - func(c *gin.Context, m LimitMap[string], k string, v *int64) { - atomic.AddInt64(v, 1) - atomic.AddInt64(total, 1) - }, + }, ToManyRequest(), IpMinuteLimitDeferFn, ) return CustomFlowLimit(a, num, clearNum...) } -func IpMinuteLimitDeferFn(total *int64) func(_ *gin.Context, m LimitMap[string], k string, _ *int64) { - return func(_ *gin.Context, m LimitMap[string], k string, _ *int64) { - atomic.AddInt64(total, -1) - cNum := number.Min(int(atomic.LoadInt64(m.ClearNum)), 1) - minu := strings.Split(k, "|")[1] - if int(atomic.LoadInt64(total)) < cNum { - m.Mux.Lock() - for key := range m.Map { - t := strings.Split(key, "|")[1] - if minu != t { - delete(m.Map, key) - } +func IpMinuteLimitDeferFn(_ *gin.Context, m LimitMap[string], k string, _, currentTotalFlow *int64) { + cNum := min(int(atomic.LoadInt64(m.ClearNum)), 1) + minu := strings.Split(k, "|")[1] + if int(atomic.LoadInt64(currentTotalFlow)) < cNum { + m.Mux.Lock() + for key := range m.Map { + t := strings.Split(key, "|")[1] + if minu != t { + delete(m.Map, key) } - m.Mux.Unlock() } + m.Mux.Unlock() } } diff --git a/app/route/route.go b/app/route/route.go index c396ec6..3eb8388 100644 --- a/app/route/route.go +++ b/app/route/route.go @@ -120,7 +120,7 @@ func SetupRouter() *gin.Engine { r.GET("/feed", actions.SiteFeed) r.GET("/comments/feed", actions.CommentsFeed) commentMiddleWare, _ := middleware.FlowLimit(5, c.SingleIpSearchNum, c.CacheTime.SleepTime) - commentIpMiddleware, _ := middleware.IpLimit(5, 5) + commentIpMiddleware, _ := middleware.IpLimit(5, 2) r.POST("/comment", commentMiddleWare, commentIpMiddleware, actions.PostComment) r.NoRoute(actions.ThemeHook(constraints.NoRoute))