diff --git a/go.mod b/go.mod index 14aba28..27035a5 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/jinzhu/gorm v1.9.16 github.com/json-iterator/go v1.1.11 // indirect + github.com/juju/ratelimit v1.0.1 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/magiconair/properties v1.8.5 // indirect github.com/mailru/easyjson v0.7.7 // indirect diff --git a/internal/middleware/context_timeout.go b/internal/middleware/context_timeout.go new file mode 100644 index 0000000..9ee09ec --- /dev/null +++ b/internal/middleware/context_timeout.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "context" + "github.com/gin-gonic/gin" + "time" +) + +func ContextTimeout(t time.Duration) func(c *gin.Context) { + return func(c *gin.Context) { + ctx, cancel := context.WithTimeout(c.Request.Context(), t) + defer cancel() + + c.Request = c.Request.WithContext(ctx) + c.Next() + } +} diff --git a/internal/middleware/limiter.go b/internal/middleware/limiter.go new file mode 100644 index 0000000..875b172 --- /dev/null +++ b/internal/middleware/limiter.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "blog/pkg/app" + "blog/pkg/errorcode" + "blog/pkg/limiter" + "github.com/gin-gonic/gin" +) + +func RateLimiter(l limiter.LimiterIface) gin.HandlerFunc { + return func(c *gin.Context) { + key := l.Key(c) + if bucket, ok := l.GetBucket(key); ok { + count := bucket.TakeAvailable(1) + if count == 0 { + response := app.NewResponse(c) + response.ToErrorResponse(errorcode.TooManyRequests) + c.Abort() + return + } + } + + c.Next() + } +} diff --git a/internal/routess/router.go b/internal/routess/router.go index 9528f3d..8633a9c 100644 --- a/internal/routess/router.go +++ b/internal/routess/router.go @@ -6,19 +6,33 @@ import ( "blog/internal/middleware" "blog/internal/routess/api" v1 "blog/internal/routess/api/v1" + "blog/pkg/limiter" "github.com/gin-gonic/gin" swaggerFiles "github.com/swaggo/files" ginSwagger "github.com/swaggo/gin-swagger" "net/http" + "time" ) +var methodLimiters = limiter.NewMethodLimiter().AddBuckets(limiter.LimiterBucketRule{ + Key: "/auth", + FillInterval: time.Second, + Capacity: 10, + Quantum: 10, +}) + func NewRouter() *gin.Engine { r := gin.New() - r.Use(gin.Logger()) - r.Use(gin.Recovery()) + if global.ServerSetting.RunMode == "debug" { + r.Use(gin.Logger()) + r.Use(gin.Recovery()) + } else { + r.Use(middleware.AccessLog()) + r.Use(middleware.Recovery()) + } r.Use(middleware.Translations()) - r.Use(middleware.Recovery()) - r.Use(middleware.AccessLog()) + r.Use(middleware.RateLimiter(methodLimiters)) + r.Use(middleware.ContextTimeout(60 * time.Second)) article := v1.NewArticle() tag := v1.NewTag() upload := api.NewUpload() diff --git a/pkg/limiter/limiter.go b/pkg/limiter/limiter.go new file mode 100644 index 0000000..9fcce3a --- /dev/null +++ b/pkg/limiter/limiter.go @@ -0,0 +1,24 @@ +package limiter + +import ( + "github.com/gin-gonic/gin" + "github.com/juju/ratelimit" + "time" +) + +type LimiterIface interface { + Key(c *gin.Context) string + GetBucket(key string) (*ratelimit.Bucket, bool) + AddBuckets(rules ...LimiterBucketRule) LimiterIface +} + +type Limiter struct { + limiterBuckets map[string]*ratelimit.Bucket +} + +type LimiterBucketRule struct { + Key string + FillInterval time.Duration + Capacity int64 + Quantum int64 +} diff --git a/pkg/limiter/method_limiter.go b/pkg/limiter/method_limiter.go new file mode 100644 index 0000000..50cdf88 --- /dev/null +++ b/pkg/limiter/method_limiter.go @@ -0,0 +1,42 @@ +package limiter + +import ( + "github.com/gin-gonic/gin" + "github.com/juju/ratelimit" + "strings" +) + +type MethodLimiter struct { + *Limiter +} + +func NewMethodLimiter() LimiterIface { + return MethodLimiter{ + Limiter: &Limiter{limiterBuckets: make(map[string]*ratelimit.Bucket)}, + } +} + +func (l MethodLimiter) Key(c *gin.Context) string { + uri := c.Request.RequestURI + index := strings.Index(uri, "?") + if index == -1 { + return uri + } + + return uri[:index] +} + +func (l MethodLimiter) GetBucket(key string) (*ratelimit.Bucket, bool) { + bucket, ok := l.limiterBuckets[key] + return bucket, ok +} + +func (l MethodLimiter) AddBuckets(rules ...LimiterBucketRule) LimiterIface { + for _, rule := range rules { + if _, ok := l.limiterBuckets[rule.Key]; !ok { + l.limiterBuckets[rule.Key] = ratelimit.NewBucketWithQuantum(rule.FillInterval, rule.Capacity, rule.Quantum) + } + } + + return l +}