diff --git a/helper/maps/map.go b/helper/maps/map.go index 54557b0..d11f2a4 100644 --- a/helper/maps/map.go +++ b/helper/maps/map.go @@ -49,6 +49,11 @@ func AnyAnyToStrAny(m map[any]any) (r map[string]any) { return } +func IsExists[K comparable, V any](m map[K]V, k K) bool { + _, ok := m[k] + return ok +} + func Reduce[T, V any, K comparable](m map[K]V, fn func(K, V, T) T, r T) T { for k, v := range m { r = fn(k, v, r) diff --git a/helper/number/number_test.go b/helper/number/number_test.go index 5cc7a94..4cfb19f 100644 --- a/helper/number/number_test.go +++ b/helper/number/number_test.go @@ -2,6 +2,7 @@ package number import ( "fmt" + "golang.org/x/exp/constraints" "reflect" "testing" ) @@ -195,10 +196,10 @@ func TestRand(t *testing.T) { } func TestAbs(t *testing.T) { - type args[T Number] struct { + type args[T constraints.Integer | constraints.Float] struct { n T } - type testCase[T Number] struct { + type testCase[T constraints.Integer | constraints.Float] struct { name string args args[T] want T diff --git a/helper/strings/strings.go b/helper/strings/strings.go index c1af80b..8252a48 100644 --- a/helper/strings/strings.go +++ b/helper/strings/strings.go @@ -3,7 +3,9 @@ package strings import ( "crypto/md5" "fmt" + "golang.org/x/exp/constraints" "io" + "strconv" "strings" ) @@ -20,6 +22,17 @@ func Join(s ...string) (str string) { return } +func ToInteger[T constraints.Integer](s string, defaults T) T { + if s == "" { + return defaults + } + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return defaults + } + return T(i) +} + func Md5(str string) string { h := md5.New() _, err := io.WriteString(h, str) diff --git a/helper/strings/strings_test.go b/helper/strings/strings_test.go index a04457c..636bb39 100644 --- a/helper/strings/strings_test.go +++ b/helper/strings/strings_test.go @@ -1,6 +1,9 @@ package strings -import "testing" +import ( + "golang.org/x/exp/constraints" + "testing" +) func TestStrJoin(t *testing.T) { type args struct { @@ -21,3 +24,32 @@ func TestStrJoin(t *testing.T) { }) } } + +func TestToInteger(t *testing.T) { + type args[T constraints.Integer] struct { + s string + z T + } + type testCase[T constraints.Integer] struct { + name string + args args[T] + want T + } + tests := []testCase[int64]{ + { + name: "t1", + args: args[int64]{ + "10", + 0, + }, + want: int64(10), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ToInteger[int64](tt.args.s, tt.args.z); got != tt.want { + t.Errorf("StrToInt() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/actions/comment.go b/internal/actions/comment.go index 846bdca..8f6faca 100644 --- a/internal/actions/comment.go +++ b/internal/actions/comment.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/fthvgb1/wp-go/helper/slice" + str "github.com/fthvgb1/wp-go/helper/strings" "github.com/fthvgb1/wp-go/internal/mail" "github.com/fthvgb1/wp-go/internal/pkg/cache" "github.com/fthvgb1/wp-go/internal/pkg/config" @@ -14,7 +15,6 @@ import ( "io" "net/http" "net/url" - "strconv" "strings" "time" ) @@ -92,8 +92,8 @@ func PostComment(c *gin.Context) { } cc := c.Copy() go func() { - id, err := strconv.ParseUint(i, 10, 64) - if err != nil { + id := str.ToInteger[uint64](i, 0) + if id <= 0 { logs.ErrPrintln(err, "获取文档id", i) return } diff --git a/internal/actions/detail.go b/internal/actions/detail.go index 54e9085..c40647b 100644 --- a/internal/actions/detail.go +++ b/internal/actions/detail.go @@ -2,6 +2,7 @@ package actions import ( "fmt" + str "github.com/fthvgb1/wp-go/helper/strings" "github.com/fthvgb1/wp-go/internal/pkg/cache" "github.com/fthvgb1/wp-go/internal/pkg/logs" "github.com/fthvgb1/wp-go/internal/plugins" @@ -10,7 +11,6 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "net/http" - "strconv" ) type detailHandler struct { @@ -47,18 +47,11 @@ func Detail(c *gin.Context) { t := theme.GetTemplateName() theme.Hook(t, code, c, ginH, plugins.Detail, status) }() - id := c.Param("id") - Id := 0 - if id != "" { - Id, err = strconv.Atoi(id) - if err != nil { - return - } - } - ID := uint64(Id) + ID := str.ToInteger[uint64](c.Param("id"), 0) + maxId, err := cache.GetMaxPostId(c) logs.ErrPrintln(err, "get max post id") - if ID > maxId || err != nil { + if ID > maxId || ID <= 0 || err != nil { return } post, err := cache.GetPostById(c, ID) @@ -92,12 +85,7 @@ func Detail(c *gin.Context) { ginH["post"] = post ginH["showComment"] = showComment ginH["prev"] = prev - depth := wpconfig.Options.Value("thread_comments_depth") - d, err := strconv.Atoi(depth) - if err != nil { - logs.ErrPrintln(err, "get comment depth ", depth) - d = 5 - } + d := str.ToInteger(wpconfig.Options.Value("thread_comments_depth"), 5) ginH["maxDep"] = d ginH["next"] = next ginH["user"] = user diff --git a/internal/actions/index.go b/internal/actions/index.go index 47c6b2a..852d758 100644 --- a/internal/actions/index.go +++ b/internal/actions/index.go @@ -3,6 +3,7 @@ package actions import ( "errors" "fmt" + "github.com/fthvgb1/wp-go/helper/maps" "github.com/fthvgb1/wp-go/helper/number" "github.com/fthvgb1/wp-go/helper/slice" str "github.com/fthvgb1/wp-go/helper/strings" @@ -52,13 +53,12 @@ type indexHandle struct { } func newIndexHandle(ctx *gin.Context) *indexHandle { - size := wpconfig.Options.Value("posts_per_page") - si, _ := strconv.Atoi(size) + size := str.ToInteger(wpconfig.Options.Value("posts_per_page"), 10) return &indexHandle{ c: ctx, session: sessions.Default(ctx), page: 1, - pageSize: si, + pageSize: size, paginationStep: 1, titleL: wpconfig.Options.Value("blogname"), titleR: wpconfig.Options.Value("blogdescription"), @@ -93,24 +93,20 @@ func (h *indexHandle) getSearchKey() string { return fmt.Sprintf("action:%s|%s|%s|%s|%s|%s|%d|%d", h.author, h.search, h.orderBy, h.order, h.category, h.categoryType, h.page, h.pageSize) } -var orders = []string{"asc", "desc"} +var orders = map[string]struct{}{"asc": {}, "desc": {}} func (h *indexHandle) parseParams() (err error) { h.order = h.c.Query("order") - - if !slice.IsContained(h.order, orders) { + if !maps.IsExists(orders, h.order) { order := config.GetConfig().PostOrder h.order = "asc" - if order != "" && slice.IsContained(order, orders) { + if order != "" && maps.IsExists(orders, order) { h.order = order } } year := h.c.Param("year") if year != "" { - y, er := strconv.Atoi(year) - if er != nil { - return err - } + y := str.ToInteger(year, -1) if y > time.Now().Year() || y <= 1970 { return errors.New(str.Join("year err : ", year)) } @@ -120,11 +116,8 @@ func (h *indexHandle) parseParams() (err error) { } month := h.c.Param("month") if month != "" { - m, err := strconv.Atoi(month) - if err != nil { - return err - } - if _, ok := months[m]; !ok { + m := str.ToInteger(month, -1) + if !maps.IsExists(months, m) { return errors.New(str.Join("months err ", month)) } @@ -137,27 +130,26 @@ func (h *indexHandle) parseParams() (err error) { h.scene = plugins.Archive } category := h.c.Param("category") - if category == "" { - category = h.c.Param("tag") - if category != "" { - h.scene = plugins.Tag - allNames := cache.AllCategoryTagsNames(h.c, plugins.Tag) - if _, ok := allNames[category]; !ok { - return errors.New(str.Join("not exists tag ", category)) - } - h.categoryType = "post_tag" - h.header = fmt.Sprintf("标签: %s", category) - } - } else { + if category != "" { h.scene = plugins.Category - allNames := cache.AllCategoryTagsNames(h.c, plugins.Category) - if _, ok := allNames[category]; !ok { + if !maps.IsExists(cache.AllCategoryTagsNames(h.c, plugins.Category), category) { return errors.New(str.Join("not exists category ", category)) } h.categoryType = "category" h.header = fmt.Sprintf("分类: %s", category) + h.category = category } - h.category = category + tag := h.c.Param("tag") + if tag != "" { + h.scene = plugins.Tag + if !maps.IsExists(cache.AllCategoryTagsNames(h.c, plugins.Tag), tag) { + return errors.New(str.Join("not exists tag ", tag)) + } + h.categoryType = "post_tag" + h.header = fmt.Sprintf("标签: %s", tag) + h.category = tag + } + username := h.c.Param("author") if username != "" { allUsername, er := cache.GetAllUsername(h.c) @@ -165,7 +157,7 @@ func (h *indexHandle) parseParams() (err error) { err = er return } - if _, ok := allUsername[username]; !ok { + if !maps.IsExists(allUsername, username) { err = errors.New(str.Join("user ", username, " is not exists")) return } @@ -179,9 +171,9 @@ func (h *indexHandle) parseParams() (err error) { "post_author", "=", strconv.FormatUint(user.Id, 10), "int", }) } - if category != "" { + if h.category != "" { h.where = append(h.where, []string{ - "d.name", category, + "d.name", h.category, }, []string{"taxonomy", h.categoryType}) h.join = append(h.join, []string{ "a", "left join", "wp_term_relationships b", "a.Id=b.object_id", @@ -190,7 +182,7 @@ func (h *indexHandle) parseParams() (err error) { }, []string{ "left join", "wp_terms d", "c.term_id=d.term_id", }) - h.setTitleLR(category, wpconfig.Options.Value("blogname")) + h.setTitleLR(h.category, wpconfig.Options.Value("blogname")) } s := h.c.Query("s") if s != "" && strings.Replace(s, " ", "", -1) != "" { @@ -206,15 +198,7 @@ func (h *indexHandle) parseParams() (err error) { h.search = s h.scene = plugins.Search } - p := h.c.Query("paged") - if p == "" { - p = h.c.Param("page") - } - if p != "" { - if pa, err := strconv.Atoi(p); err == nil { - h.page = pa - } - } + h.page = str.ToInteger(h.c.Param("page"), 1) total := int(atomic.LoadInt64(&dao.TotalRaw)) if total > 0 && total < (h.page-1)*h.pageSize { h.page = 1 diff --git a/internal/pkg/cache/feed.go b/internal/pkg/cache/feed.go index 1fa195b..0231ba5 100644 --- a/internal/pkg/cache/feed.go +++ b/internal/pkg/cache/feed.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/fthvgb1/wp-go/cache" "github.com/fthvgb1/wp-go/helper/slice" + str "github.com/fthvgb1/wp-go/helper/strings" "github.com/fthvgb1/wp-go/internal/pkg/logs" "github.com/fthvgb1/wp-go/internal/pkg/models" "github.com/fthvgb1/wp-go/internal/plugins" @@ -11,7 +12,6 @@ import ( "github.com/fthvgb1/wp-go/plugin/digest" "github.com/fthvgb1/wp-go/rss2" "github.com/gin-gonic/gin" - "strconv" "strings" "time" ) @@ -93,17 +93,10 @@ func feed(arg ...any) (xml []string, err error) { func postFeed(arg ...any) (x string, err error) { c := arg[0].(*gin.Context) id := arg[1].(string) - Id := 0 - if id != "" { - Id, err = strconv.Atoi(id) - if err != nil { - return - } - } - ID := uint64(Id) + ID := str.ToInteger[uint64](id, 0) maxId, err := GetMaxPostId(c) logs.ErrPrintln(err, "get max post id") - if ID > maxId || err != nil { + if ID < 1 || ID > maxId || err != nil { return } post, err := GetPostById(c, ID)