diff --git a/middleware/flowLimit.go b/middleware/flowLimit.go index e8ef405..4e051d7 100644 --- a/middleware/flowLimit.go +++ b/middleware/flowLimit.go @@ -2,7 +2,6 @@ package middleware import ( "github.com/gin-gonic/gin" - "github/fthvgb1/wp-go/vars" "math/rand" "net/http" "strings" @@ -12,11 +11,12 @@ import ( ) type IpLimitMap struct { - mux *sync.Mutex - m map[string]*int64 + mux *sync.Mutex + m map[string]*int64 + singleIpSearchNum int64 } -func FlowLimit() func(ctx *gin.Context) { +func FlowLimit(maxRequestSleepNum, maxRequestNum, singleIpSearchNum int64, sleepTime []time.Duration) func(ctx *gin.Context) { var flow int64 rand.Seed(time.Now().UnixNano()) randFn := func(start, end time.Duration) time.Duration { @@ -24,8 +24,9 @@ func FlowLimit() func(ctx *gin.Context) { return time.Duration(rand.Intn(int(end-start)) + int(start)) } m := IpLimitMap{ - mux: &sync.Mutex{}, - m: make(map[string]*int64), + mux: &sync.Mutex{}, + m: make(map[string]*int64), + singleIpSearchNum: singleIpSearchNum, } statPath := map[string]struct{}{ "wp-includes": {}, @@ -50,10 +51,10 @@ func FlowLimit() func(ctx *gin.Context) { defer func() { atomic.AddInt64(&flow, -1) }() - if flow >= vars.Conf.MaxRequestSleepNum && flow <= vars.Conf.MaxRequestNum { - t := randFn(vars.Conf.SleepTime[0], vars.Conf.SleepTime[1]) + if flow >= maxRequestSleepNum && flow <= maxRequestNum { + t := randFn(sleepTime[0], sleepTime[1]) time.Sleep(t) - } else if flow > vars.Conf.MaxRequestNum { + } else if flow > maxRequestNum { c.String(http.StatusForbidden, "请求太多了,服务器君表示压力山大==!, 请稍后访问") c.Abort() @@ -68,14 +69,13 @@ func (m *IpLimitMap) searchLimit(start bool, c *gin.Context, ip string, f []stri if f[0] == "" && c.Query("s") != "" { if start { i, ok := m.m[ip] - num := vars.Conf.SingleIpSearchNum if !ok { m.mux.Lock() i = new(int64) m.m[ip] = i m.mux.Unlock() } - if num > 0 && *i >= num { + if m.singleIpSearchNum > 0 && *i >= m.singleIpSearchNum { isForbid = true return } diff --git a/route/route.go b/route/route.go index 06780f0..c295474 100644 --- a/route/route.go +++ b/route/route.go @@ -21,6 +21,13 @@ func SetupRouter() *gin.Engine { // Disable Console Color // gin.DisableConsoleColor() r := gin.New() + if len(vars.Conf.TrustIps) > 0 { + err := r.SetTrustedProxies(vars.Conf.TrustIps) + if err != nil { + panic(err) + } + } + r.HTMLRender = templates.NewFsTemplate(template.FuncMap{ "unescaped": func(s string) any { return template.HTML(s) @@ -29,7 +36,7 @@ func SetupRouter() *gin.Engine { return t.Format("2006年 01月 02日") }, }).SetTemplate() - r.Use(gin.Logger(), middleware.FlowLimit(), gin.Recovery(), middleware.SetStaticFileCache) + r.Use(gin.Logger(), middleware.FlowLimit(vars.Conf.MaxRequestSleepNum, vars.Conf.MaxRequestNum, vars.Conf.SingleIpSearchNum, vars.Conf.SleepTime), gin.Recovery(), middleware.SetStaticFileCache) //gzip 因为一般会用nginx做反代时自动使用gzip,所以go这边本身可以不用 if vars.Conf.Gzip { r.Use(gzip.Gzip(gzip.DefaultCompression, gzip.WithExcludedPaths([]string{ diff --git a/vars/config.go b/vars/config.go index 7f0e70e..bd19dd4 100644 --- a/vars/config.go +++ b/vars/config.go @@ -33,6 +33,7 @@ type Config struct { CommentsCacheTime time.Duration `yaml:"commentsCacheTime"` Gzip bool `yaml:"gzip"` PostCommentUrl string `yaml:"postCommentUrl"` + TrustIps []string `yaml:"trustIps"` } type Mysql struct {