diff --git a/middleware/flowLimit.go b/middleware/flowLimit.go index 91b6907..e7ef1c2 100644 --- a/middleware/flowLimit.go +++ b/middleware/flowLimit.go @@ -13,7 +13,7 @@ import ( type IpLimitMap struct { mux *sync.Mutex - m map[string]int + m map[string]*int64 } func FlowLimit() func(ctx *gin.Context) { @@ -25,7 +25,7 @@ func FlowLimit() func(ctx *gin.Context) { } m := IpLimitMap{ mux: &sync.Mutex{}, - m: make(map[string]int), + m: make(map[string]*int64), } statPath := map[string]struct{}{ "wp-includes": {}, @@ -39,56 +39,54 @@ func FlowLimit() func(ctx *gin.Context) { c.Next() return } + s := false ip := c.ClientIP() - if m.searchLimit(true, c, ip, f) { + defer m.searchLimit(false, c, ip, f, &s) + if m.searchLimit(true, c, ip, f, &s) { c.Abort() return } atomic.AddInt64(&flow, 1) + defer func() { + atomic.AddInt64(&flow, -1) + }() if flow >= vars.Conf.MaxRequestSleepNum && flow <= vars.Conf.MaxRequestNum { t := randFn(vars.Conf.SleepTime[0], vars.Conf.SleepTime[1]) time.Sleep(t) } else if flow > vars.Conf.MaxRequestNum { c.String(http.StatusForbidden, "请求太多了,服务器君压力山大中==!, 请稍后访问") c.Abort() - atomic.AddInt64(&flow, -1) - m.searchLimit(false, c, ip, f) + return } - c.Next() - m.searchLimit(false, c, ip, f) - atomic.AddInt64(&flow, -1) + } } -func (m *IpLimitMap) set(k string, n int) { - m.mux.Lock() - defer m.mux.Unlock() - m.m[k] = n -} - -func (m *IpLimitMap) searchLimit(a bool, c *gin.Context, ip string, f []string) (isForbid bool) { - +func (m *IpLimitMap) searchLimit(start bool, c *gin.Context, ip string, f []string, s *bool) (isForbid bool) { if f[0] == "" && c.Query("s") != "" { - if a { + if start { i, ok := m.m[ip] - if ok { - num := vars.Conf.SingleIpSearchNum - if num < 1 { - num = 10 - } - if i > num { - return true - } - } else { - i = 0 + num := vars.Conf.SingleIpSearchNum + if !ok { + m.mux.Lock() + i = new(int64) + m.m[ip] = i + m.mux.Unlock() } - i++ - m.set(ip, i) - } else { - m.set(ip, m.m[ip]-1) - if m.m[ip] == 0 { + if num > 0 && *i >= num { + isForbid = true + return + } + *s = true + atomic.AddInt64(i, 1) + return + } + i, ok := m.m[ip] + if ok && *s && *i > 0 { + atomic.AddInt64(i, -1) + if *i == 0 { m.mux.Lock() delete(m.m, ip) m.mux.Unlock() diff --git a/route/route.go b/route/route.go index c9e350d..c337630 100644 --- a/route/route.go +++ b/route/route.go @@ -27,7 +27,7 @@ func SetupRouter() *gin.Engine { return t.Format("2006年 01月 02日") }, }).SetTemplate() - r.Use(gin.Logger(), gin.Recovery(), middleware.FlowLimit(), middleware.SetStaticFileCache) + r.Use(gin.Logger(), middleware.FlowLimit(), gin.Recovery(), middleware.SetStaticFileCache) //gzip 因为一般会用nginx做反代时自动使用gzip,所以go这边本身可以不用 /*r.Use(gzip.Gzip(gzip.DefaultCompression, gzip.WithExcludedPaths([]string{ "/wp-includes/", "/wp-content/", diff --git a/vars/config.go b/vars/config.go index 1a7fa58..f1956c9 100644 --- a/vars/config.go +++ b/vars/config.go @@ -27,7 +27,7 @@ type Config struct { MaxRequestSleepNum int64 `yaml:"maxRequestSleepNum"` SleepTime []time.Duration `yaml:"sleepTime"` MaxRequestNum int64 `yaml:"maxRequestNum"` - SingleIpSearchNum int `yaml:"singleIpSearchNum"` + SingleIpSearchNum int64 `yaml:"singleIpSearchNum"` } type Mysql struct {