From 274610be8b604672cfb1e2c1b4e341b6ac992444 Mon Sep 17 00:00:00 2001 From: xing Date: Mon, 17 Mar 2025 21:23:39 +0800 Subject: [PATCH] optimize code and add ip per limit middleware --- app/middleware/iplimit.go | 215 ++++++++++++++++++------ app/plugins/devexample/plugintt/make.sh | 2 +- 2 files changed, 165 insertions(+), 52 deletions(-) mode change 100644 => 100755 app/middleware/iplimit.go diff --git a/app/middleware/iplimit.go b/app/middleware/iplimit.go old mode 100644 new mode 100755 index 12ef767..103c54c --- a/app/middleware/iplimit.go +++ b/app/middleware/iplimit.go @@ -1,79 +1,192 @@ package middleware import ( + "github.com/fthvgb1/wp-go/helper/number" + str "github.com/fthvgb1/wp-go/helper/strings" "github.com/gin-gonic/gin" "net/http" + "strings" "sync" "sync/atomic" + "time" ) -type ipLimitMap struct { - mux *sync.RWMutex - m map[string]*int64 - limitNum *int64 - clearNum *int64 +type LimitMap[K comparable] struct { + Mux *sync.RWMutex + Map map[K]*int64 + LimitNum *int64 + ClearNum *int64 } -func IpLimit(num int64, clearNum ...int64) (func(ctx *gin.Context), func(int64, ...int64)) { - m := ipLimitMap{ - mux: &sync.RWMutex{}, - m: make(map[string]*int64), - limitNum: new(int64), - clearNum: new(int64), +type FlowLimits[K comparable] struct { + GetKeyFn func(ctx *gin.Context) K + LimitedFns func(ctx *gin.Context) + DeferClearFn func(c *gin.Context, m LimitMap[K], k K, v *int64) + AddFn func(c *gin.Context, m LimitMap[K], k K, v *int64) +} + +func NewFlowLimits[K comparable](getKeyFn func(ctx *gin.Context) K, + limitedFns func(ctx *gin.Context), + deferClearFn func(c *gin.Context, m LimitMap[K], k K, v *int64), + addFns ...func(c *gin.Context, m LimitMap[K], k K, v *int64), +) *FlowLimits[K] { + + f := &FlowLimits[K]{ + GetKeyFn: getKeyFn, + LimitedFns: limitedFns, + DeferClearFn: deferClearFn, } + fn := f.Adds + if len(addFns) > 0 { + fn = addFns[0] + } + f.AddFn = fn + return f +} + +func (f FlowLimits[K]) GetKey(c *gin.Context) K { + return f.GetKeyFn(c) +} + +func (f FlowLimits[K]) Limit(c *gin.Context) { + f.LimitedFns(c) +} +func (f FlowLimits[K]) DeferClear(c *gin.Context, m LimitMap[K], k K, v *int64) { + f.DeferClearFn(c, m, k, v) +} + +func (f FlowLimits[K]) Add(c *gin.Context, m LimitMap[K], k K, v *int64) { + f.AddFn(c, m, k, v) +} + +func (f FlowLimits[K]) Adds(_ *gin.Context, _ LimitMap[K], _ K, v *int64) { + atomic.AddInt64(v, 1) +} + +type MapFlowLimit[K comparable] interface { + GetKey(c *gin.Context) K + Limit(c *gin.Context) + DeferClear(c *gin.Context, m LimitMap[K], k K, v *int64) + Add(c *gin.Context, m LimitMap[K], k K, v *int64) +} + +func CustomFlowLimit[K comparable](a MapFlowLimit[K], maxRequestNum int64, clearNum ...int64) (func(ctx *gin.Context), func(int64, ...int64)) { + m := LimitMap[K]{ + Mux: &sync.RWMutex{}, + Map: make(map[K]*int64), + LimitNum: new(int64), + ClearNum: new(int64), + } + atomic.StoreInt64(m.LimitNum, maxRequestNum) + if len(clearNum) > 0 { + atomic.StoreInt64(m.ClearNum, clearNum[0]) + } + fn := func(num int64, clearNum ...int64) { - atomic.StoreInt64(m.limitNum, num) + atomic.StoreInt64(m.LimitNum, num) if len(clearNum) > 0 { - atomic.StoreInt64(m.clearNum, clearNum[0]) + atomic.StoreInt64(m.ClearNum, clearNum[0]) } } - fn(num, clearNum...) - return func(c *gin.Context) { - if atomic.LoadInt64(m.limitNum) <= 0 { + if atomic.LoadInt64(m.LimitNum) <= 0 { c.Next() return } - ip := c.ClientIP() - m.mux.RLock() - i, ok := m.m[ip] - m.mux.RUnlock() - + key := a.GetKey(c) + m.Mux.RLock() + i, ok := m.Map[key] + m.Mux.RUnlock() if !ok { - m.mux.Lock() + m.Mux.Lock() i = new(int64) - m.m[ip] = i - m.mux.Unlock() + m.Map[key] = i + m.Mux.Unlock() } - - defer func() { - atomic.AddInt64(i, -1) - if atomic.LoadInt64(i) <= 0 { - cNum := int(atomic.LoadInt64(m.clearNum)) - if cNum <= 0 { - m.mux.Lock() - delete(m.m, ip) - m.mux.Unlock() - return - } - - m.mux.RLock() - l := len(m.m) - m.mux.RUnlock() - if l < cNum { - m.mux.Lock() - delete(m.m, ip) - m.mux.Unlock() - } - } - }() - - if atomic.LoadInt64(i) >= atomic.LoadInt64(m.limitNum) { - c.String(http.StatusForbidden, "请求太多了,服务器君表示压力山大==!, 请稍后访问") - c.Abort() + a.Add(c, m, key, i) + defer a.DeferClear(c, m, key, i) + if atomic.LoadInt64(i) > atomic.LoadInt64(m.LimitNum) { + a.Limit(c) return } - atomic.AddInt64(i, 1) c.Next() }, fn } + +func IpLimitClear[K comparable](_ *gin.Context, m LimitMap[K], key K, i *int64) { + atomic.AddInt64(i, -1) + if atomic.LoadInt64(i) <= 0 { + cNum := int(atomic.LoadInt64(m.ClearNum)) + if cNum <= 0 { + m.Mux.Lock() + delete(m.Map, key) + m.Mux.Unlock() + return + } + + m.Mux.RLock() + l := len(m.Map) + m.Mux.RUnlock() + if l < cNum { + m.Mux.Lock() + for k, v := range m.Map { + if atomic.LoadInt64(v) < 1 { + delete(m.Map, k) + } + } + m.Mux.Unlock() + } + } +} + +func ToManyRequest(messages ...string) func(c *gin.Context) { + message := "请求太多了,服务器君表示压力山大==!, 请稍后访问" + if len(messages) > 0 { + message = messages[0] + } + return func(c *gin.Context) { + c.String(http.StatusForbidden, message) + c.Abort() + } +} + +func IpLimit(num int64, clearNum ...int64) (func(ctx *gin.Context), func(int64, ...int64)) { + a := NewFlowLimits(func(c *gin.Context) string { + return c.ClientIP() + }, ToManyRequest(), IpLimitClear) + return CustomFlowLimit[string](a, num, clearNum...) +} + +func IpMinuteLimit(num int64, clearNum ...int64) (func(ctx *gin.Context), func(int64, ...int64)) { + total := new(int64) + a := NewFlowLimits(func(c *gin.Context) string { + return str.Join(c.ClientIP(), "|", time.Now().Format("2006-01-02 15:04")) + }, + ToManyRequest(), + IpMinuteLimitDeferFn(total), + func(c *gin.Context, m LimitMap[string], k string, v *int64) { + atomic.AddInt64(v, 1) + atomic.AddInt64(total, 1) + }, + ) + + return CustomFlowLimit(a, num, clearNum...) +} + +func IpMinuteLimitDeferFn(total *int64) func(_ *gin.Context, m LimitMap[string], k string, _ *int64) { + return func(_ *gin.Context, m LimitMap[string], k string, _ *int64) { + atomic.AddInt64(total, -1) + cNum := number.Min(int(atomic.LoadInt64(m.ClearNum)), 1) + minu := strings.Split(k, "|")[1] + if int(atomic.LoadInt64(total)) < cNum { + m.Mux.Lock() + for key := range m.Map { + t := strings.Split(key, "|")[1] + if minu != t { + delete(m.Map, key) + } + } + m.Mux.Unlock() + } + } +} diff --git a/app/plugins/devexample/plugintt/make.sh b/app/plugins/devexample/plugintt/make.sh index 27a751e..4243ea3 100644 --- a/app/plugins/devexample/plugintt/make.sh +++ b/app/plugins/devexample/plugintt/make.sh @@ -1,4 +1,4 @@ -#/bin/bash +#!/bin/bash # copy plugintt to other dir and remove .dev suffix # note the go version and build tool flag must same to server build