diff --git a/config.example.yaml b/config.example.yaml index 18f4ef1..63fd54a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -27,30 +27,42 @@ ssl: cert: "" key: "" -# 最近文章缓存时间 -recentPostCacheTime: 5m -# 分类缓存时间 -categoryCacheTime: 5m -# 上下篇缓存时间 -contextPostCacheTime: 10h -# 最近评论缓存时间 -recentCommentsCacheTime: 5m -# 摘要缓存时间 -digestCacheTime: 5m +cacheTime: + # 静态资源缓存时间Cache-Control + cacheControl: 5d + # 最近文章缓存时间 + recentPostCacheTime: 5m + # 分类缓存时间 + categoryCacheTime: 5m + # 上下篇缓存时间 + contextPostCacheTime: 10h + # 最近评论缓存时间 + recentCommentsCacheTime: 5m + # 摘要缓存时间 + digestCacheTime: 5m + # 文档列表id页缓存 包括默认列表、分类 + postListCacheTime: 1h + # 搜索文档id缓存时间 + searchPostCacheTime: 5m + # 月归档文章id缓存时间 + monthPostCacheTime: 1h + # 文档数据缓存时间 + postDataCacheTime: 1h + # 文章评论缓存时间 + postCommentsCacheTime: 5m + # 定时清理缓存周期时间 + crontabClearCacheTime: 5m + # 文档最大id缓存时间 + maxPostIdCacheTime: 1h + # 用户信息缓存时间 + userInfoCacheTime: 24h + # 单独评论缓存时间 + commentsCacheTime: 24h + # 主题的页眉图片缓存时间 + themeHeaderImagCacheTime: 5m # 摘要字数 digestWordCount: 300 -# 文档列表id页缓存 包括默认列表、分类 -postListCacheTime: 1h -# 搜索文档id缓存时间 -searchPostCacheTime: 5m -# 月归档文章id缓存时间 -monthPostCacheTime: 1h -# 文档数据缓存时间 -postDataCacheTime: 1h -# 文章评论缓存时间 -postCommentsCacheTime: 5m -# 定时清理缓存周期时间 -crontabClearCacheTime: 5m + # 到达指定并发请求数时随机sleep maxRequestSleepNum: 100 # 随机sleep时间 @@ -59,14 +71,7 @@ sleepTime: [1s,3s] maxRequestNum: 500 # 单ip同时最大搜索请求数 singleIpSearchNum: 10 -# 文档最大id缓存时间 -maxPostIdCacheTime: 1h -# 用户信息缓存时间 -userInfoCacheTime: 24h -# 单独评论缓存时间 -commentsCacheTime: 24h -# 主题的页眉图片缓存时间 -themeHeaderImagCacheTime: 5m + # Gzip gzip: false diff --git a/helper/func.go b/helper/func.go index 302ca66..ff53435 100644 --- a/helper/func.go +++ b/helper/func.go @@ -8,6 +8,13 @@ func ToAny[T any](v T) any { return v } +func Or[T any](is bool, left, right T) T { + if is { + return left + } + return right +} + func StructColumnToSlice[T any, M any](arr []M, field string) (r []T) { for i := 0; i < len(arr); i++ { v := reflect.ValueOf(arr[i]).FieldByName(field).Interface() diff --git a/helper/maps/map.go b/helper/maps/map.go index 54557b0..d11f2a4 100644 --- a/helper/maps/map.go +++ b/helper/maps/map.go @@ -49,6 +49,11 @@ func AnyAnyToStrAny(m map[any]any) (r map[string]any) { return } +func IsExists[K comparable, V any](m map[K]V, k K) bool { + _, ok := m[k] + return ok +} + func Reduce[T, V any, K comparable](m map[K]V, fn func(K, V, T) T, r T) T { for k, v := range m { r = fn(k, v, r) diff --git a/helper/number/number_test.go b/helper/number/number_test.go index 5cc7a94..4cfb19f 100644 --- a/helper/number/number_test.go +++ b/helper/number/number_test.go @@ -2,6 +2,7 @@ package number import ( "fmt" + "golang.org/x/exp/constraints" "reflect" "testing" ) @@ -195,10 +196,10 @@ func TestRand(t *testing.T) { } func TestAbs(t *testing.T) { - type args[T Number] struct { + type args[T constraints.Integer | constraints.Float] struct { n T } - type testCase[T Number] struct { + type testCase[T constraints.Integer | constraints.Float] struct { name string args args[T] want T diff --git a/helper/strings/strings.go b/helper/strings/strings.go index c1af80b..8252a48 100644 --- a/helper/strings/strings.go +++ b/helper/strings/strings.go @@ -3,7 +3,9 @@ package strings import ( "crypto/md5" "fmt" + "golang.org/x/exp/constraints" "io" + "strconv" "strings" ) @@ -20,6 +22,17 @@ func Join(s ...string) (str string) { return } +func ToInteger[T constraints.Integer](s string, defaults T) T { + if s == "" { + return defaults + } + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return defaults + } + return T(i) +} + func Md5(str string) string { h := md5.New() _, err := io.WriteString(h, str) diff --git a/helper/strings/strings_test.go b/helper/strings/strings_test.go index a04457c..636bb39 100644 --- a/helper/strings/strings_test.go +++ b/helper/strings/strings_test.go @@ -1,6 +1,9 @@ package strings -import "testing" +import ( + "golang.org/x/exp/constraints" + "testing" +) func TestStrJoin(t *testing.T) { type args struct { @@ -21,3 +24,32 @@ func TestStrJoin(t *testing.T) { }) } } + +func TestToInteger(t *testing.T) { + type args[T constraints.Integer] struct { + s string + z T + } + type testCase[T constraints.Integer] struct { + name string + args args[T] + want T + } + tests := []testCase[int64]{ + { + name: "t1", + args: args[int64]{ + "10", + 0, + }, + want: int64(10), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ToInteger[int64](tt.args.s, tt.args.z); got != tt.want { + t.Errorf("StrToInt() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/actions/comment.go b/internal/actions/comment.go index 40a7079..8f6faca 100644 --- a/internal/actions/comment.go +++ b/internal/actions/comment.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/fthvgb1/wp-go/helper/slice" + str "github.com/fthvgb1/wp-go/helper/strings" "github.com/fthvgb1/wp-go/internal/mail" "github.com/fthvgb1/wp-go/internal/pkg/cache" "github.com/fthvgb1/wp-go/internal/pkg/config" @@ -14,7 +15,6 @@ import ( "io" "net/http" "net/url" - "strconv" "strings" "time" ) @@ -34,6 +34,7 @@ func PostComment(c *gin.Context) { c.Writer.WriteString(err.Error()) } }() + conf := config.GetConfig() if err != nil { return } @@ -43,7 +44,7 @@ func PostComment(c *gin.Context) { m := c.PostForm("email") comment := c.PostForm("comment") c.Request.Body = io.NopCloser(bytes.NewBuffer(data)) - req, err := http.NewRequest("POST", config.Conf.Load().PostCommentUrl, strings.NewReader(c.Request.PostForm.Encode())) + req, err := http.NewRequest("POST", conf.PostCommentUrl, strings.NewReader(c.Request.PostForm.Encode())) if err != nil { return } @@ -68,7 +69,7 @@ func PostComment(c *gin.Context) { err = er return } - cu, er := url.Parse(config.Conf.Load().PostCommentUrl) + cu, er := url.Parse(conf.PostCommentUrl) if er != nil { err = er return @@ -91,8 +92,8 @@ func PostComment(c *gin.Context) { } cc := c.Copy() go func() { - id, err := strconv.ParseUint(i, 10, 64) - if err != nil { + id := str.ToInteger[uint64](i, 0) + if id <= 0 { logs.ErrPrintln(err, "获取文档id", i) return } @@ -102,8 +103,8 @@ func PostComment(c *gin.Context) { return } su := fmt.Sprintf("%s: %s[%s]发表了评论对文档[%v]的评论", wpconfig.Options.Value("siteurl"), author, m, post.PostTitle) - err = mail.SendMail([]string{config.Conf.Load().Mail.User}, su, comment) - logs.ErrPrintln(err, "发送邮件", config.Conf.Load().Mail.User, su, comment) + err = mail.SendMail([]string{conf.Mail.User}, su, comment) + logs.ErrPrintln(err, "发送邮件", conf.Mail.User, su, comment) }() s, er := io.ReadAll(ress.Body) diff --git a/internal/actions/detail.go b/internal/actions/detail.go index d06b68f..1270b5d 100644 --- a/internal/actions/detail.go +++ b/internal/actions/detail.go @@ -2,15 +2,16 @@ package actions import ( "fmt" + str "github.com/fthvgb1/wp-go/helper/strings" "github.com/fthvgb1/wp-go/internal/pkg/cache" "github.com/fthvgb1/wp-go/internal/pkg/logs" + "github.com/fthvgb1/wp-go/internal/pkg/models" "github.com/fthvgb1/wp-go/internal/plugins" "github.com/fthvgb1/wp-go/internal/theme" "github.com/fthvgb1/wp-go/internal/wpconfig" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "net/http" - "strconv" ) type detailHandler struct { @@ -19,9 +20,10 @@ type detailHandler struct { func Detail(c *gin.Context) { var err error - recent := cache.RecentPosts(c, 5) + var post models.Posts + recent := cache.RecentPosts(c, 5, true) archive := cache.Archives(c) - categoryItems := cache.Categories(c) + categoryItems := cache.CategoriesTags(c, plugins.Category) recentComments := cache.RecentComments(c, 5) var ginH = gin.H{ "title": wpconfig.Options.Value("blogname"), @@ -29,6 +31,7 @@ func Detail(c *gin.Context) { "archives": archive, "categories": categoryItems, "recentComments": recentComments, + "post": post, } isApproveComment := false status := plugins.Ok @@ -47,21 +50,14 @@ func Detail(c *gin.Context) { t := theme.GetTemplateName() theme.Hook(t, code, c, ginH, plugins.Detail, status) }() - id := c.Param("id") - Id := 0 - if id != "" { - Id, err = strconv.Atoi(id) - if err != nil { - return - } - } - ID := uint64(Id) + ID := str.ToInteger[uint64](c.Param("id"), 0) + maxId, err := cache.GetMaxPostId(c) logs.ErrPrintln(err, "get max post id") - if ID > maxId || err != nil { + if ID > maxId || ID <= 0 || err != nil { return } - post, err := cache.GetPostById(c, ID) + post, err = cache.GetPostById(c, ID) if post.Id == 0 || err != nil || post.PostStatus != "publish" { return } @@ -71,10 +67,13 @@ func Detail(c *gin.Context) { showComment = true } user := cache.GetUserById(c, post.PostAuthor) - plugins.PasswordProjectTitle(&post) - if post.PostPassword != "" && pw != post.PostPassword { - plugins.PasswdProjectContent(&post) - showComment = false + + if post.PostPassword != "" { + plugins.PasswordProjectTitle(&post) + if pw != post.PostPassword { + plugins.PasswdProjectContent(&post) + showComment = false + } } else if s, ok := cache.NewCommentCache().Get(c, c.Request.URL.RawQuery); ok && s != "" && (post.PostPassword == "" || post.PostPassword != "" && pw == post.PostPassword) { c.Writer.WriteHeader(http.StatusOK) c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") @@ -92,12 +91,7 @@ func Detail(c *gin.Context) { ginH["post"] = post ginH["showComment"] = showComment ginH["prev"] = prev - depth := wpconfig.Options.Value("thread_comments_depth") - d, err := strconv.Atoi(depth) - if err != nil { - logs.ErrPrintln(err, "get comment depth ", depth) - d = 5 - } + d := str.ToInteger(wpconfig.Options.Value("thread_comments_depth"), 5) ginH["maxDep"] = d ginH["next"] = next ginH["user"] = user diff --git a/internal/actions/index.go b/internal/actions/index.go index ec19906..bc77bc6 100644 --- a/internal/actions/index.go +++ b/internal/actions/index.go @@ -3,6 +3,7 @@ package actions import ( "errors" "fmt" + "github.com/fthvgb1/wp-go/helper/maps" "github.com/fthvgb1/wp-go/helper/number" "github.com/fthvgb1/wp-go/helper/slice" str "github.com/fthvgb1/wp-go/helper/strings" @@ -52,13 +53,12 @@ type indexHandle struct { } func newIndexHandle(ctx *gin.Context) *indexHandle { - size := wpconfig.Options.Value("posts_per_page") - si, _ := strconv.Atoi(size) + size := str.ToInteger(wpconfig.Options.Value("posts_per_page"), 10) return &indexHandle{ c: ctx, session: sessions.Default(ctx), page: 1, - pageSize: si, + pageSize: size, paginationStep: 1, titleL: wpconfig.Options.Value("blogname"), titleR: wpconfig.Options.Value("blogdescription"), @@ -93,24 +93,20 @@ func (h *indexHandle) getSearchKey() string { return fmt.Sprintf("action:%s|%s|%s|%s|%s|%s|%d|%d", h.author, h.search, h.orderBy, h.order, h.category, h.categoryType, h.page, h.pageSize) } -var orders = []string{"asc", "desc"} +var orders = map[string]struct{}{"asc": {}, "desc": {}} func (h *indexHandle) parseParams() (err error) { h.order = h.c.Query("order") - - if !slice.IsContained(h.order, orders) { - order := config.Conf.Load().PostOrder + if !maps.IsExists(orders, h.order) { + order := config.GetConfig().PostOrder h.order = "asc" - if order != "" && slice.IsContained(order, orders) { + if order != "" && maps.IsExists(orders, order) { h.order = order } } year := h.c.Param("year") if year != "" { - y, er := strconv.Atoi(year) - if er != nil { - return err - } + y := str.ToInteger(year, -1) if y > time.Now().Year() || y <= 1970 { return errors.New(str.Join("year err : ", year)) } @@ -120,11 +116,8 @@ func (h *indexHandle) parseParams() (err error) { } month := h.c.Param("month") if month != "" { - m, err := strconv.Atoi(month) - if err != nil { - return err - } - if _, ok := months[m]; !ok { + m := str.ToInteger(month, -1) + if !maps.IsExists(months, m) { return errors.New(str.Join("months err ", month)) } @@ -137,27 +130,26 @@ func (h *indexHandle) parseParams() (err error) { h.scene = plugins.Archive } category := h.c.Param("category") - if category == "" { - category = h.c.Param("tag") - if category != "" { - h.scene = plugins.Tag - allNames := cache.AllTagsNames(h.c) - if _, ok := allNames[category]; !ok { - return errors.New(str.Join("not exists tag ", category)) - } - h.categoryType = "post_tag" - h.header = fmt.Sprintf("标签: %s", category) - } - } else { + if category != "" { h.scene = plugins.Category - allNames := cache.AllCategoryNames(h.c) - if _, ok := allNames[category]; !ok { + if !maps.IsExists(cache.AllCategoryTagsNames(h.c, plugins.Category), category) { return errors.New(str.Join("not exists category ", category)) } h.categoryType = "category" h.header = fmt.Sprintf("分类: %s", category) + h.category = category } - h.category = category + tag := h.c.Param("tag") + if tag != "" { + h.scene = plugins.Tag + if !maps.IsExists(cache.AllCategoryTagsNames(h.c, plugins.Tag), tag) { + return errors.New(str.Join("not exists tag ", tag)) + } + h.categoryType = "post_tag" + h.header = fmt.Sprintf("标签: %s", tag) + h.category = tag + } + username := h.c.Param("author") if username != "" { allUsername, er := cache.GetAllUsername(h.c) @@ -165,7 +157,7 @@ func (h *indexHandle) parseParams() (err error) { err = er return } - if _, ok := allUsername[username]; !ok { + if !maps.IsExists(allUsername, username) { err = errors.New(str.Join("user ", username, " is not exists")) return } @@ -179,9 +171,9 @@ func (h *indexHandle) parseParams() (err error) { "post_author", "=", strconv.FormatUint(user.Id, 10), "int", }) } - if category != "" { + if h.category != "" { h.where = append(h.where, []string{ - "d.name", category, + "d.name", h.category, }, []string{"taxonomy", h.categoryType}) h.join = append(h.join, []string{ "a", "left join", "wp_term_relationships b", "a.Id=b.object_id", @@ -190,7 +182,7 @@ func (h *indexHandle) parseParams() (err error) { }, []string{ "left join", "wp_terms d", "c.term_id=d.term_id", }) - h.setTitleLR(category, wpconfig.Options.Value("blogname")) + h.setTitleLR(h.category, wpconfig.Options.Value("blogname")) } s := h.c.Query("s") if s != "" && strings.Replace(s, " ", "", -1) != "" { @@ -206,15 +198,7 @@ func (h *indexHandle) parseParams() (err error) { h.search = s h.scene = plugins.Search } - p := h.c.Query("paged") - if p == "" { - p = h.c.Param("page") - } - if p != "" { - if pa, err := strconv.Atoi(p); err == nil { - h.page = pa - } - } + h.page = str.ToInteger(h.c.Param("page"), 1) total := int(atomic.LoadInt64(&dao.TotalRaw)) if total > 0 && total < (h.page-1)*h.pageSize { h.page = 1 @@ -236,8 +220,8 @@ func Index(c *gin.Context) { var totalRaw int var err error archive := cache.Archives(c) - recent := cache.RecentPosts(c, 5) - categoryItems := cache.Categories(c) + recent := cache.RecentPosts(c, 5, true) + categoryItems := cache.CategoriesTags(c, plugins.Category) recentComments := cache.RecentComments(c, 5) ginH := gin.H{ "err": err, @@ -247,6 +231,7 @@ func Index(c *gin.Context) { "search": h.search, "header": h.header, "recentComments": recentComments, + "posts": posts, } defer func() { code := http.StatusOK @@ -292,18 +277,16 @@ func Index(c *gin.Context) { pw := h.session.Get("post_password") plug := plugins.NewPostPlugin(c, h.scene) for i, post := range posts { - plugins.PasswordProjectTitle(&posts[i]) - if post.PostPassword != "" && pw != post.PostPassword { - plugins.PasswdProjectContent(&posts[i]) + if post.PostPassword != "" { + plugins.PasswordProjectTitle(&posts[i]) + if pw != post.PostPassword { + plugins.PasswdProjectContent(&posts[i]) + } } else { plugins.ApplyPlugin(plug, &posts[i]) } } - for i, post := range recent { - if post.PostPassword != "" && pw != post.PostPassword { - plugins.PasswdProjectContent(&recent[i]) - } - } + q := c.Request.URL.Query().Encode() if q != "" { q = fmt.Sprintf("?%s", q) diff --git a/internal/cmd/main.go b/internal/cmd/main.go index b65c393..a75b098 100644 --- a/internal/cmd/main.go +++ b/internal/cmd/main.go @@ -55,11 +55,11 @@ func initConf(c string) (err error) { return } - err = db.InitDb() + database, err := db.InitDb() if err != nil { return } - model.InitDB(db.NewSqlxDb(db.Db)) + model.InitDB(model.NewSqlxQuery(database)) err = wpconfig.InitOptions() if err != nil { return @@ -72,7 +72,7 @@ func initConf(c string) (err error) { } func cronClearCache() { - t := time.NewTicker(config.Conf.Load().CrontabClearCacheTime) + t := time.NewTicker(config.GetConfig().CacheTime.CrontabClearCacheTime) for { select { case <-t.C: @@ -85,7 +85,7 @@ func cronClearCache() { func flushCache() { defer func() { if r := recover(); r != nil { - err := mail.SendMail([]string{config.Conf.Load().Mail.User}, "清空缓存失败", fmt.Sprintf("err:[%s]", r)) + err := mail.SendMail([]string{config.GetConfig().Mail.User}, "清空缓存失败", fmt.Sprintf("err:[%s]", r)) logs.ErrPrintln(err, "发邮件失败") } }() @@ -129,7 +129,7 @@ func signalNotify() { func main() { go signalNotify() Gin, reloadFn := route.SetupRouter() - c := config.Conf.Load() + c := config.GetConfig() middleWareReloadFn = reloadFn if c.Ssl.Key != "" && c.Ssl.Cert != "" { err := Gin.RunTLS(address, c.Ssl.Cert, c.Ssl.Key) diff --git a/internal/cmd/route/route.go b/internal/cmd/route/route.go index 2763e11..78ff6a5 100644 --- a/internal/cmd/route/route.go +++ b/internal/cmd/route/route.go @@ -18,7 +18,7 @@ func SetupRouter() (*gin.Engine, func()) { // Disable Console Color // gin.DisableConsoleColor() r := gin.New() - c := config.Conf.Load() + c := config.GetConfig() if len(c.TrustIps) > 0 { err := r.SetTrustedProxies(c.TrustIps) if err != nil { @@ -29,7 +29,7 @@ func SetupRouter() (*gin.Engine, func()) { r.HTMLRender = theme.GetTemplate() validServerName, reloadValidServerNameFn := middleware.ValidateServerNames() - fl, flReload := middleware.FlowLimit(c.MaxRequestSleepNum, c.MaxRequestNum, c.SleepTime) + fl, flReload := middleware.FlowLimit(c.MaxRequestSleepNum, c.MaxRequestNum, c.CacheTime.SleepTime) r.Use( gin.Logger(), validServerName, @@ -76,15 +76,15 @@ func SetupRouter() (*gin.Engine, func()) { r.GET("/p/:id/feed", actions.PostFeed) r.GET("/feed", actions.Feed) r.GET("/comments/feed", actions.CommentsFeed) - cfl, _ := middleware.FlowLimit(c.MaxRequestSleepNum, 5, c.SleepTime) + cfl, _ := middleware.FlowLimit(c.MaxRequestSleepNum, 5, c.CacheTime.SleepTime) r.POST("/comment", cfl, actions.PostComment) if c.Pprof != "" { pprof.Register(r, c.Pprof) } fn := func() { reloadValidServerNameFn() - c := config.Conf.Load() - flReload(c.MaxRequestSleepNum, c.MaxRequestNum, c.SleepTime) + c := config.GetConfig() + flReload(c.MaxRequestSleepNum, c.MaxRequestNum, c.CacheTime.SleepTime) slRload(c.SingleIpSearchNum) } return r, fn diff --git a/internal/mail/mail.go b/internal/mail/mail.go index 81595aa..45dc867 100644 --- a/internal/mail/mail.go +++ b/internal/mail/mail.go @@ -18,7 +18,7 @@ func SendMail(mailTo []string, subject string, body string, files ...string) err m := gomail.NewMessage( gomail.SetEncoding(gomail.Base64), ) - c := config.Conf.Load() + c := config.GetConfig() m.SetHeader("From", m.FormatAddress(c.Mail.User, c.Mail.Alias, diff --git a/internal/middleware/sendmail.go b/internal/middleware/sendmail.go index 0870d1f..aded58d 100644 --- a/internal/middleware/sendmail.go +++ b/internal/middleware/sendmail.go @@ -43,7 +43,7 @@ func RecoverAndSendMail(w io.Writer) func(ctx *gin.Context) { ) er := mail.SendMail( - []string{config.Conf.Load().Mail.User}, + []string{config.GetConfig().Mail.User}, fmt.Sprintf("%s%s %s 发生错误", fmt.Sprintf(wpconfig.Options.Value("siteurl")), c.FullPath(), time.Now().Format(time.RFC1123Z)), content) if er != nil { diff --git a/internal/middleware/staticFileCache.go b/internal/middleware/staticFileCache.go index 9d3bdc2..993325b 100644 --- a/internal/middleware/staticFileCache.go +++ b/internal/middleware/staticFileCache.go @@ -1,15 +1,26 @@ package middleware import ( - "github.com/fthvgb1/wp-go/helper/slice" + "fmt" + "github.com/fthvgb1/wp-go/internal/pkg/config" "github.com/gin-gonic/gin" "strings" ) +var path = map[string]struct{}{ + "wp-includes": {}, + "wp-content": {}, + "favicon.ico": {}, +} + func SetStaticFileCache(c *gin.Context) { f := strings.Split(strings.TrimLeft(c.FullPath(), "/"), "/") - if len(f) > 0 && slice.IsContained(f[0], []string{"wp-includes", "wp-content", "favicon.ico"}) { - c.Header("Cache-Control", "private, max-age=86400") + if _, ok := path[f[0]]; ok { + t := config.GetConfig().CacheTime.CacheControl + if t > 0 { + c.Header("Cache-Control", fmt.Sprintf("private, max-age=%d", int(t.Seconds()))) + } } + c.Next() } diff --git a/internal/middleware/validateservername.go b/internal/middleware/validateservername.go index 444b80d..6e3f261 100644 --- a/internal/middleware/validateservername.go +++ b/internal/middleware/validateservername.go @@ -11,7 +11,7 @@ import ( func ValidateServerNames() (func(ctx *gin.Context), func()) { var serverName safety.Map[string, struct{}] fn := func() { - r := config.Conf.Load().TrustServerNames + r := config.GetConfig().TrustServerNames if len(r) > 0 { for _, name := range r { serverName.Store(name, struct{}{}) diff --git a/internal/pkg/cache/cache.go b/internal/pkg/cache/cache.go index 8c2e992..67d6ab0 100644 --- a/internal/pkg/cache/cache.go +++ b/internal/pkg/cache/cache.go @@ -3,11 +3,13 @@ package cache import ( "context" "github.com/fthvgb1/wp-go/cache" + "github.com/fthvgb1/wp-go/helper" "github.com/fthvgb1/wp-go/helper/slice" "github.com/fthvgb1/wp-go/internal/pkg/config" "github.com/fthvgb1/wp-go/internal/pkg/dao" "github.com/fthvgb1/wp-go/internal/pkg/logs" "github.com/fthvgb1/wp-go/internal/pkg/models" + "github.com/fthvgb1/wp-go/internal/plugins" "sync" "time" ) @@ -46,39 +48,43 @@ var headerImagesCache *cache.MapCache[string, []models.PostThumbnail] var ctx context.Context func InitActionsCommonCache() { - c := config.Conf.Load() + c := config.GetConfig() archivesCaches = &Arch{ - mutex: &sync.Mutex{}, - setCacheFunc: dao.Archives, + mutex: &sync.Mutex{}, + fn: dao.Archives, } - searchPostIdsCache = cache.NewMemoryMapCacheByFn[string](dao.SearchPostIds, c.SearchPostCacheTime) + searchPostIdsCache = cache.NewMemoryMapCacheByFn[string](dao.SearchPostIds, c.CacheTime.SearchPostCacheTime) - postListIdsCache = cache.NewMemoryMapCacheByFn[string](dao.SearchPostIds, c.PostListCacheTime) + postListIdsCache = cache.NewMemoryMapCacheByFn[string](dao.SearchPostIds, c.CacheTime.PostListCacheTime) - monthPostsCache = cache.NewMemoryMapCacheByFn[string](dao.MonthPost, c.MonthPostCacheTime) + monthPostsCache = cache.NewMemoryMapCacheByFn[string](dao.MonthPost, c.CacheTime.MonthPostCacheTime) - postContextCache = cache.NewMemoryMapCacheByFn[uint64](dao.GetPostContext, c.ContextPostCacheTime) + postContextCache = cache.NewMemoryMapCacheByFn[uint64](dao.GetPostContext, c.CacheTime.ContextPostCacheTime) - postsCache = cache.NewMemoryMapCacheByBatchFn(dao.GetPostsByIds, c.PostDataCacheTime) + postsCache = cache.NewMemoryMapCacheByBatchFn(dao.GetPostsByIds, c.CacheTime.PostDataCacheTime) - postMetaCache = cache.NewMemoryMapCacheByBatchFn(dao.GetPostMetaByPostIds, c.PostDataCacheTime) + postMetaCache = cache.NewMemoryMapCacheByBatchFn(dao.GetPostMetaByPostIds, c.CacheTime.PostDataCacheTime) - categoryAndTagsCaches = cache.NewVarCache(dao.CategoriesAndTags, c.CategoryCacheTime) + categoryAndTagsCaches = cache.NewVarCache(dao.CategoriesAndTags, c.CacheTime.CategoryCacheTime) - recentPostsCaches = cache.NewVarCache(dao.RecentPosts, c.RecentPostCacheTime) + recentPostsCaches = cache.NewVarCache(dao.RecentPosts, c.CacheTime.RecentPostCacheTime) - recentCommentsCaches = cache.NewVarCache(dao.RecentComments, c.RecentCommentsCacheTime) + recentCommentsCaches = cache.NewVarCache(dao.RecentComments, c.CacheTime.RecentCommentsCacheTime) - postCommentCaches = cache.NewMemoryMapCacheByFn[uint64](dao.PostComments, c.PostCommentsCacheTime) + postCommentCaches = cache.NewMemoryMapCacheByFn[uint64](dao.PostComments, c.CacheTime.PostCommentsCacheTime) - maxPostIdCache = cache.NewVarCache(dao.GetMaxPostId, c.MaxPostIdCacheTime) + maxPostIdCache = cache.NewVarCache(dao.GetMaxPostId, c.CacheTime.MaxPostIdCacheTime) - usersCache = cache.NewMemoryMapCacheByFn[uint64](dao.GetUserById, c.UserInfoCacheTime) + usersCache = cache.NewMemoryMapCacheByFn[uint64](dao.GetUserById, c.CacheTime.UserInfoCacheTime) - usersNameCache = cache.NewMemoryMapCacheByFn[string](dao.GetUserByName, c.UserInfoCacheTime) + usersNameCache = cache.NewMemoryMapCacheByFn[string](dao.GetUserByName, c.CacheTime.UserInfoCacheTime) - commentsCache = cache.NewMemoryMapCacheByBatchFn(dao.GetCommentByIds, c.CommentsCacheTime) + commentsCache = cache.NewMemoryMapCacheByBatchFn(dao.GetCommentByIds, c.CacheTime.CommentsCacheTime) + + allUsernameCache = cache.NewVarCache(dao.AllUsername, c.CacheTime.UserInfoCacheTime) + + headerImagesCache = cache.NewMemoryMapCacheByFn[string](getHeaderImages, c.CacheTime.ThemeHeaderImagCacheTime) feedCache = cache.NewVarCache(feed, time.Hour) @@ -88,10 +94,6 @@ func InitActionsCommonCache() { newCommentCache = cache.NewMemoryMapCacheByFn[string, string](nil, 15*time.Minute) - allUsernameCache = cache.NewVarCache(dao.AllUsername, c.UserInfoCacheTime) - - headerImagesCache = cache.NewMemoryMapCacheByFn[string](getHeaderImages, c.ThemeHeaderImagCacheTime) - ctx = context.Background() InitFeed() @@ -132,62 +134,44 @@ func Archives(ctx context.Context) (r []models.PostArchive) { } type Arch struct { - data []models.PostArchive - mutex *sync.Mutex - setCacheFunc func(context.Context) ([]models.PostArchive, error) - month time.Month + data []models.PostArchive + mutex *sync.Mutex + fn func(context.Context) ([]models.PostArchive, error) + month time.Month } -func (c *Arch) getArchiveCache(ctx context.Context) []models.PostArchive { - l := len(c.data) +func (a *Arch) getArchiveCache(ctx context.Context) []models.PostArchive { + l := len(a.data) m := time.Now().Month() - if l > 0 && c.month != m || l < 1 { - r, err := c.setCacheFunc(ctx) + if l > 0 && a.month != m || l < 1 { + r, err := a.fn(ctx) if err != nil { logs.ErrPrintln(err, "set cache err[%s]") return nil } - c.mutex.Lock() - defer c.mutex.Unlock() - c.month = m - c.data = r + a.mutex.Lock() + defer a.mutex.Unlock() + a.month = m + a.data = r } - return c.data + return a.data } -func Categories(ctx context.Context) []models.TermsMy { +func CategoriesTags(ctx context.Context, t ...int) []models.TermsMy { r, err := categoryAndTagsCaches.GetCache(ctx, time.Second, ctx) logs.ErrPrintln(err, "get category err") - r = slice.Filter(r, func(my models.TermsMy) bool { - return my.Taxonomy == "category" - }) + if len(t) > 0 { + return slice.Filter(r, func(my models.TermsMy) bool { + return helper.Or(t[0] == plugins.Tag, "post_tag", "category") == my.Taxonomy + }) + } return r } - -func Tags(ctx context.Context) []models.TermsMy { - r, err := categoryAndTagsCaches.GetCache(ctx, time.Second, ctx) - logs.ErrPrintln(err, "get category err") - r = slice.Filter(r, func(my models.TermsMy) bool { - return my.Taxonomy == "post_tag" - }) - return r -} -func AllTagsNames(ctx context.Context) map[string]struct{} { +func AllCategoryTagsNames(ctx context.Context, c int) map[string]struct{} { r, err := categoryAndTagsCaches.GetCache(ctx, time.Second, ctx) logs.ErrPrintln(err, "get category err") return slice.FilterAndToMap(r, func(t models.TermsMy) (string, struct{}, bool) { - if t.Taxonomy == "post_tag" { - return t.Name, struct{}{}, true - } - return "", struct{}{}, false - }) -} - -func AllCategoryNames(ctx context.Context) map[string]struct{} { - r, err := categoryAndTagsCaches.GetCache(ctx, time.Second, ctx) - logs.ErrPrintln(err, "get category err") - return slice.FilterAndToMap(r, func(t models.TermsMy) (string, struct{}, bool) { - if t.Taxonomy == "category" { + if helper.Or(c == plugins.Tag, "post_tag", "category") == t.Taxonomy { return t.Name, struct{}{}, true } return "", struct{}{}, false diff --git a/internal/pkg/cache/feed.go b/internal/pkg/cache/feed.go index 1fa195b..e4a5518 100644 --- a/internal/pkg/cache/feed.go +++ b/internal/pkg/cache/feed.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/fthvgb1/wp-go/cache" "github.com/fthvgb1/wp-go/helper/slice" + str "github.com/fthvgb1/wp-go/helper/strings" "github.com/fthvgb1/wp-go/internal/pkg/logs" "github.com/fthvgb1/wp-go/internal/pkg/models" "github.com/fthvgb1/wp-go/internal/plugins" @@ -11,7 +12,6 @@ import ( "github.com/fthvgb1/wp-go/plugin/digest" "github.com/fthvgb1/wp-go/rss2" "github.com/gin-gonic/gin" - "strconv" "strings" "time" ) @@ -46,7 +46,7 @@ func PostFeedCache() *cache.MapCache[string, string] { func feed(arg ...any) (xml []string, err error) { c := arg[0].(*gin.Context) - r := RecentPosts(c, 10) + r := RecentPosts(c, 10, true) ids := slice.Map(r, func(t models.Posts) uint64 { return t.Id }) @@ -54,21 +54,22 @@ func feed(arg ...any) (xml []string, err error) { if err != nil { return } + site := wpconfig.Options.Value("siteurl") rs := templateRss rs.LastBuildDate = time.Now().Format(timeFormat) rs.Items = slice.Map(posts, func(t models.Posts) rss2.Item { desc := "无法提供摘要。这是一篇受保护的文章。" - plugins.PasswordProjectTitle(&t) if t.PostPassword != "" { + plugins.PasswordProjectTitle(&t) plugins.PasswdProjectContent(&t) } else { desc = digest.Raw(t.PostContent, 55, fmt.Sprintf("/p/%d", t.Id)) } l := "" if t.CommentStatus == "open" && t.CommentCount > 0 { - l = fmt.Sprintf("%s/p/%d#comments", wpconfig.Options.Value("siteurl"), t.Id) + l = fmt.Sprintf("%s/p/%d#comments", site, t.Id) } else if t.CommentStatus == "open" && t.CommentCount == 0 { - l = fmt.Sprintf("%s/p/%d#respond", wpconfig.Options.Value("siteurl"), t.Id) + l = fmt.Sprintf("%s/p/%d#respond", site, t.Id) } user := GetUserById(c, t.PostAuthor) @@ -80,8 +81,8 @@ func feed(arg ...any) (xml []string, err error) { Content: t.PostContent, Category: strings.Join(t.Categories, "、"), CommentLink: l, - CommentRss: fmt.Sprintf("%s/p/%d/feed", wpconfig.Options.Value("siteurl"), t.Id), - Link: fmt.Sprintf("%s/p/%d", wpconfig.Options.Value("siteurl"), t.Id), + CommentRss: fmt.Sprintf("%s/p/%d/feed", site, t.Id), + Link: fmt.Sprintf("%s/p/%d", site, t.Id), Description: desc, PubDate: t.PostDateGmt.Format(timeFormat), } @@ -93,42 +94,36 @@ func feed(arg ...any) (xml []string, err error) { func postFeed(arg ...any) (x string, err error) { c := arg[0].(*gin.Context) id := arg[1].(string) - Id := 0 - if id != "" { - Id, err = strconv.Atoi(id) - if err != nil { - return - } - } - ID := uint64(Id) + ID := str.ToInteger[uint64](id, 0) maxId, err := GetMaxPostId(c) logs.ErrPrintln(err, "get max post id") - if ID > maxId || err != nil { + if ID < 1 || ID > maxId || err != nil { return } post, err := GetPostById(c, ID) if post.Id == 0 || err != nil { return } - plugins.PasswordProjectTitle(&post) comments, err := PostComments(c, post.Id) if err != nil { return } rs := templateRss + site := wpconfig.Options.Value("siteurl") rs.Title = fmt.Sprintf("《%s》的评论", post.PostTitle) - rs.AtomLink = fmt.Sprintf("%s/p/%d/feed", wpconfig.Options.Value("siteurl"), post.Id) - rs.Link = fmt.Sprintf("%s/p/%d", wpconfig.Options.Value("siteurl"), post.Id) + rs.AtomLink = fmt.Sprintf("%s/p/%d/feed", site, post.Id) + rs.Link = fmt.Sprintf("%s/p/%d", site, post.Id) rs.LastBuildDate = time.Now().Format(timeFormat) if post.PostPassword != "" { + plugins.PasswordProjectTitle(&post) + plugins.PasswdProjectContent(&post) if len(comments) > 0 { - plugins.PasswdProjectContent(&post) t := comments[len(comments)-1] rs.Items = []rss2.Item{ { Title: fmt.Sprintf("评价者:%s", t.CommentAuthor), - Link: fmt.Sprintf("%s/p/%d#comment-%d", wpconfig.Options.Value("siteurl"), post.Id, t.CommentId), + Link: fmt.Sprintf("%s/p/%d#comment-%d", site, post.Id, t.CommentId), Creator: t.CommentAuthor, PubDate: t.CommentDateGmt.Format(timeFormat), Guid: fmt.Sprintf("%s#comment-%d", post.Guid, t.CommentId), @@ -141,7 +136,7 @@ func postFeed(arg ...any) (x string, err error) { rs.Items = slice.Map(comments, func(t models.Comments) rss2.Item { return rss2.Item{ Title: fmt.Sprintf("评价者:%s", t.CommentAuthor), - Link: fmt.Sprintf("%s/p/%d#comment-%d", wpconfig.Options.Value("siteurl"), post.Id, t.CommentId), + Link: fmt.Sprintf("%s/p/%d#comment-%d", site, post.Id, t.CommentId), Creator: t.CommentAuthor, PubDate: t.CommentDateGmt.Format(timeFormat), Guid: fmt.Sprintf("%s#comment-%d", post.Guid, t.CommentId), @@ -160,7 +155,8 @@ func commentsFeed(args ...any) (r []string, err error) { rs := templateRss rs.Title = fmt.Sprintf("\"%s\"的评论", wpconfig.Options.Value("blogname")) rs.LastBuildDate = time.Now().Format(timeFormat) - rs.AtomLink = fmt.Sprintf("%s/comments/feed", wpconfig.Options.Value("siteurl")) + site := wpconfig.Options.Value("siteurl") + rs.AtomLink = fmt.Sprintf("%s/comments/feed", site) com, err := GetCommentByIds(c, slice.Map(commens, func(t models.Comments) uint64 { return t.CommentId })) @@ -169,10 +165,10 @@ func commentsFeed(args ...any) (r []string, err error) { } rs.Items = slice.Map(com, func(t models.Comments) rss2.Item { post, _ := GetPostById(c, t.CommentPostId) - plugins.PasswordProjectTitle(&post) desc := "评论受保护:要查看请输入密码。" content := t.CommentContent if post.PostPassword != "" { + plugins.PasswordProjectTitle(&post) plugins.PasswdProjectContent(&post) content = post.PostContent } else { @@ -181,7 +177,7 @@ func commentsFeed(args ...any) (r []string, err error) { } return rss2.Item{ Title: fmt.Sprintf("%s对《%s》的评论", t.CommentAuthor, post.PostTitle), - Link: fmt.Sprintf("%s/p/%d#comment-%d", wpconfig.Options.Value("siteurl"), post.Id, t.CommentId), + Link: fmt.Sprintf("%s/p/%d#comment-%d", site, post.Id, t.CommentId), Creator: t.CommentAuthor, Description: desc, PubDate: t.CommentDateGmt.Format(timeFormat), diff --git a/internal/pkg/cache/posts.go b/internal/pkg/cache/posts.go index b12581f..d4ea374 100644 --- a/internal/pkg/cache/posts.go +++ b/internal/pkg/cache/posts.go @@ -6,6 +6,7 @@ import ( "github.com/fthvgb1/wp-go/helper/slice" "github.com/fthvgb1/wp-go/internal/pkg/logs" "github.com/fthvgb1/wp-go/internal/pkg/models" + "github.com/fthvgb1/wp-go/internal/plugins" "github.com/gin-gonic/gin" "time" ) @@ -42,11 +43,23 @@ func GetMaxPostId(ctx *gin.Context) (uint64, error) { return maxPostIdCache.GetCache(ctx, time.Second, ctx) } -func RecentPosts(ctx context.Context, n int) (r []models.Posts) { - r, err := recentPostsCaches.GetCache(ctx, time.Second, ctx) +func RecentPosts(ctx context.Context, n int, project bool) (r []models.Posts) { + nn := n + if nn <= 5 { + nn = 10 + } + r, err := recentPostsCaches.GetCache(ctx, time.Second, ctx, nn) if n < len(r) { r = r[:n] } + if project { + r = slice.Map(r, func(t models.Posts) models.Posts { + if t.PostPassword != "" { + plugins.PasswordProjectTitle(&t) + } + return t + }) + } logs.ErrPrintln(err, "get recent post") return } diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go index ea0d070..5901c6a 100644 --- a/internal/pkg/config/config.go +++ b/internal/pkg/config/config.go @@ -8,41 +8,50 @@ import ( "time" ) -var Conf safety.Var[Config] +var config safety.Var[Config] + +func GetConfig() Config { + return config.Load() +} type Config struct { - Ssl Ssl `yaml:"ssl"` - Mysql Mysql `yaml:"mysql"` - Mail Mail `yaml:"mail"` + Ssl Ssl `yaml:"ssl"` + Mysql Mysql `yaml:"mysql"` + Mail Mail `yaml:"mail"` + CacheTime CacheTime `yaml:"cacheTime"` + DigestWordCount int `yaml:"digestWordCount"` + MaxRequestSleepNum int64 `yaml:"maxRequestSleepNum"` + MaxRequestNum int64 `yaml:"maxRequestNum"` + SingleIpSearchNum int64 `yaml:"singleIpSearchNum"` + Gzip bool `yaml:"gzip"` + PostCommentUrl string `yaml:"postCommentUrl"` + TrustIps []string `yaml:"trustIps"` + TrustServerNames []string `yaml:"trustServerNames"` + Theme string `yaml:"theme"` + PostOrder string `yaml:"postOrder"` + UploadDir string `yaml:"uploadDir"` + Pprof string `yaml:"pprof"` +} + +type CacheTime struct { + CacheControl time.Duration `yaml:"cacheControl"` RecentPostCacheTime time.Duration `yaml:"recentPostCacheTime"` CategoryCacheTime time.Duration `yaml:"categoryCacheTime"` ArchiveCacheTime time.Duration `yaml:"archiveCacheTime"` ContextPostCacheTime time.Duration `yaml:"contextPostCacheTime"` RecentCommentsCacheTime time.Duration `yaml:"recentCommentsCacheTime"` DigestCacheTime time.Duration `yaml:"digestCacheTime"` - DigestWordCount int `yaml:"digestWordCount"` PostListCacheTime time.Duration `yaml:"postListCacheTime"` SearchPostCacheTime time.Duration `yaml:"searchPostCacheTime"` MonthPostCacheTime time.Duration `yaml:"monthPostCacheTime"` PostDataCacheTime time.Duration `yaml:"postDataCacheTime"` PostCommentsCacheTime time.Duration `yaml:"postCommentsCacheTime"` CrontabClearCacheTime time.Duration `yaml:"crontabClearCacheTime"` - MaxRequestSleepNum int64 `yaml:"maxRequestSleepNum"` - SleepTime []time.Duration `yaml:"sleepTime"` - MaxRequestNum int64 `yaml:"maxRequestNum"` - SingleIpSearchNum int64 `yaml:"singleIpSearchNum"` MaxPostIdCacheTime time.Duration `yaml:"maxPostIdCacheTime"` UserInfoCacheTime time.Duration `yaml:"userInfoCacheTime"` CommentsCacheTime time.Duration `yaml:"commentsCacheTime"` ThemeHeaderImagCacheTime time.Duration `yaml:"themeHeaderImagCacheTime"` - Gzip bool `yaml:"gzip"` - PostCommentUrl string `yaml:"postCommentUrl"` - TrustIps []string `yaml:"trustIps"` - TrustServerNames []string `yaml:"trustServerNames"` - Theme string `yaml:"theme"` - PostOrder string `yaml:"postOrder"` - UploadDir string `yaml:"uploadDir"` - Pprof string `yaml:"pprof"` + SleepTime []time.Duration `yaml:"sleepTime"` } type Ssl struct { @@ -77,7 +86,7 @@ func InitConfig(conf string) error { if err != nil { return err } - Conf.Store(c) + config.Store(c) return nil } diff --git a/internal/pkg/dao/comments.go b/internal/pkg/dao/comments.go index 91ac576..86c1e7e 100644 --- a/internal/pkg/dao/comments.go +++ b/internal/pkg/dao/comments.go @@ -2,22 +2,26 @@ package dao import ( "context" + "github.com/fthvgb1/wp-go/helper/number" "github.com/fthvgb1/wp-go/helper/slice" "github.com/fthvgb1/wp-go/internal/pkg/models" "github.com/fthvgb1/wp-go/model" - "strconv" ) // RecentComments // param context.Context func RecentComments(a ...any) (r []models.Comments, err error) { ctx := a[0].(context.Context) - return model.Find[models.Comments](ctx, model.SqlBuilder{ - {"comment_approved", "1"}, - {"post_status", "publish"}, - }, "comment_ID,comment_author,comment_post_ID,post_title", "", model.SqlBuilder{{"comment_date_gmt", "desc"}}, model.SqlBuilder{ - {"a", "left join", "wp_posts b", "a.comment_post_ID=b.ID"}, - }, nil, 10) + return model.Finds[models.Comments](ctx, model.Conditions( + model.Where(model.SqlBuilder{ + {"comment_approved", "1"}, + {"post_status", "publish"}, + }), + model.Fields("comment_ID,comment_author,comment_post_ID,post_title"), + model.Order(model.SqlBuilder{{"comment_date_gmt", "desc"}}), + model.Join(model.SqlBuilder{{"a", "left join", "wp_posts b", "a.comment_post_ID=b.ID"}}), + model.Limit(10), + )) } // PostComments @@ -26,13 +30,17 @@ func RecentComments(a ...any) (r []models.Comments, err error) { func PostComments(args ...any) ([]uint64, error) { ctx := args[0].(context.Context) postId := args[1].(uint64) - r, err := model.Find[models.Comments](ctx, model.SqlBuilder{ - {"comment_approved", "1"}, - {"comment_post_ID", "=", strconv.FormatUint(postId, 10), "int"}, - }, "comment_ID", "", model.SqlBuilder{ - {"comment_date_gmt", "asc"}, - {"comment_ID", "asc"}, - }, nil, nil, 0) + r, err := model.Finds[models.Comments](ctx, model.Conditions( + model.Where(model.SqlBuilder{ + {"comment_approved", "1"}, + {"comment_post_ID", "=", number.ToString(postId), "int"}, + }), + model.Fields("comment_ID"), + model.Order(model.SqlBuilder{ + {"comment_date_gmt", "asc"}, + {"comment_ID", "asc"}, + })), + ) if err != nil { return nil, err } diff --git a/internal/pkg/dao/common.go b/internal/pkg/dao/common.go index 710e8ee..6907c61 100644 --- a/internal/pkg/dao/common.go +++ b/internal/pkg/dao/common.go @@ -2,7 +2,6 @@ package dao import ( "context" - "fmt" "github.com/fthvgb1/wp-go/internal/pkg/models" "github.com/fthvgb1/wp-go/internal/wpconfig" "github.com/fthvgb1/wp-go/model" @@ -20,23 +19,21 @@ type PostContext struct { Next models.Posts } -func PasswordProjectTitle(post *models.Posts) { - if post.PostPassword != "" { - post.PostTitle = fmt.Sprintf("密码保护:%s", post.PostTitle) - } -} - func CategoriesAndTags(a ...any) (terms []models.TermsMy, err error) { ctx := a[0].(context.Context) var in = []any{"category", "post_tag"} - terms, err = model.Find[models.TermsMy](ctx, model.SqlBuilder{ - {"tt.count", ">", "0", "int"}, - {"tt.taxonomy", "in", ""}, - }, "t.term_id", "", model.SqlBuilder{ - {"t.name", "asc"}, - }, model.SqlBuilder{ - {"t", "inner join", "wp_term_taxonomy tt", "t.term_id = tt.term_id"}, - }, nil, 0, in) + terms, err = model.Finds[models.TermsMy](ctx, model.Conditions( + model.Where(model.SqlBuilder{ + {"tt.count", ">", "0", "int"}, + {"tt.taxonomy", "in", ""}, + }), + model.Fields("t.term_id"), + model.Order(model.SqlBuilder{{"t.name", "asc"}}), + model.Join(model.SqlBuilder{ + {"t", "inner join", "wp_term_taxonomy tt", "t.term_id = tt.term_id"}, + }), + model.In(in), + )) for i := 0; i < len(terms); i++ { if v, ok := wpconfig.Terms.Load(terms[i].Terms.TermId); ok { terms[i].Terms = v @@ -49,7 +46,13 @@ func CategoriesAndTags(a ...any) (terms []models.TermsMy, err error) { } func Archives(ctx context.Context) ([]models.PostArchive, error) { - return model.Find[models.PostArchive](ctx, model.SqlBuilder{ - {"post_type", "post"}, {"post_status", "publish"}, - }, "YEAR(post_date) AS `year`, MONTH(post_date) AS `month`, count(ID) as posts", "year,month", model.SqlBuilder{{"year", "desc"}, {"month", "desc"}}, nil, nil, 0) + return model.Finds[models.PostArchive](ctx, model.Conditions( + model.Where(model.SqlBuilder{ + {"post_type", "post"}, + {"post_status", "publish"}, + }), + model.Fields("YEAR(post_date) AS `year`, MONTH(post_date) AS `month`, count(ID) as posts"), + model.Group("year,month"), + model.Order(model.SqlBuilder{{"year", "desc"}, {"month", "desc"}}), + )) } diff --git a/internal/pkg/dao/postmeta.go b/internal/pkg/dao/postmeta.go index ceb567c..dba8891 100644 --- a/internal/pkg/dao/postmeta.go +++ b/internal/pkg/dao/postmeta.go @@ -14,9 +14,10 @@ func GetPostMetaByPostIds(args ...any) (r map[uint64]map[string]any, err error) r = make(map[uint64]map[string]any) ctx := args[0].(context.Context) ids := args[1].([]uint64) - rr, err := model.Find[models.PostMeta](ctx, model.SqlBuilder{ - {"post_id", "in", ""}, - }, "*", "", nil, nil, nil, 0, slice.ToAnySlice(ids)) + rr, err := model.Finds[models.PostMeta](ctx, model.Conditions( + model.Where(model.SqlBuilder{{"post_id", "in", ""}}), + model.In(slice.ToAnySlice(ids)), + )) if err != nil { return } @@ -24,6 +25,7 @@ func GetPostMetaByPostIds(args ...any) (r map[uint64]map[string]any, err error) if _, ok := r[postmeta.PostId]; !ok { r[postmeta.PostId] = make(map[string]any) } + r[postmeta.PostId][postmeta.MetaKey] = postmeta.MetaValue if postmeta.MetaKey == "_wp_attachment_metadata" { metadata, err := plugins.UnPHPSerialize[models.WpAttachmentMetadata](postmeta.MetaValue) if err != nil { @@ -31,11 +33,7 @@ func GetPostMetaByPostIds(args ...any) (r map[uint64]map[string]any, err error) continue } r[postmeta.PostId][postmeta.MetaKey] = metadata - - } else { - r[postmeta.PostId][postmeta.MetaKey] = postmeta.MetaValue } - } return } diff --git a/internal/pkg/dao/posts.go b/internal/pkg/dao/posts.go index b92df99..a767ed1 100644 --- a/internal/pkg/dao/posts.go +++ b/internal/pkg/dao/posts.go @@ -13,20 +13,21 @@ import ( "time" ) -func GetPostsByIds(ids ...any) (m map[uint64]models.Posts, err error) { - ctx := ids[0].(context.Context) +func GetPostsByIds(a ...any) (m map[uint64]models.Posts, err error) { + ctx := a[0].(context.Context) m = make(map[uint64]models.Posts) - id := ids[1].([]uint64) - arg := slice.ToAnySlice(id) - rawPosts, err := model.Find[models.Posts](ctx, model.SqlBuilder{{ - "Id", "in", "", - }}, "a.*,ifnull(d.name,'') category_name,ifnull(taxonomy,'') `taxonomy`", "", nil, model.SqlBuilder{{ - "a", "left join", "wp_term_relationships b", "a.Id=b.object_id", - }, { - "left join", "wp_term_taxonomy c", "b.term_taxonomy_id=c.term_taxonomy_id", - }, { - "left join", "wp_terms d", "c.term_id=d.term_id", - }}, nil, 0, arg) + ids := a[1].([]uint64) + rawPosts, err := model.Finds[models.Posts](ctx, model.Conditions( + model.Where(model.SqlBuilder{{"Id", "in", ""}}), + model.Join(model.SqlBuilder{ + {"a", "left join", "wp_term_relationships b", "a.Id=b.object_id"}, + {"left join", "wp_term_taxonomy c", "b.term_taxonomy_id=c.term_taxonomy_id"}, + {"left join", "wp_terms d", "c.term_id=d.term_id"}, + }), + model.Fields("a.*,ifnull(d.name,'') category_name,ifnull(taxonomy,'') `taxonomy`"), + model.In(slice.ToAnySlice(ids)), + )) + if err != nil { return m, err } @@ -45,7 +46,7 @@ func GetPostsByIds(ids ...any) (m map[uint64]models.Posts, err error) { } //host, _ := wpconfig.Options.Load("siteurl") host := "" - meta, _ := GetPostMetaByPostIds(ctx, id) + meta, _ := GetPostMetaByPostIds(ctx, ids) for k, pp := range postsMap { if len(pp.Categories) > 0 { t := make([]string, 0, len(pp.Categories)) @@ -97,7 +98,11 @@ func SearchPostIds(args ...any) (ids PostIds, err error) { join := args[5].(model.SqlBuilder) postType := args[6].([]any) postStatus := args[7].([]any) - res, total, err := model.SimplePagination[models.Posts](ctx, where, "ID", "", page, limit, order, join, nil, postType, postStatus) + res, total, err := model.SimplePagination[models.Posts]( + ctx, where, "ID", + "", page, limit, order, + join, nil, postType, postStatus, + ) for _, posts := range res { ids.Ids = append(ids.Ids, posts.Id) } @@ -112,7 +117,10 @@ func SearchPostIds(args ...any) (ids PostIds, err error) { func GetMaxPostId(a ...any) (uint64, error) { ctx := a[0].(context.Context) - r, err := model.SimpleFind[models.Posts](ctx, model.SqlBuilder{{"post_type", "post"}, {"post_status", "publish"}}, "max(ID) ID") + r, err := model.SimpleFind[models.Posts](ctx, + model.SqlBuilder{{"post_type", "post"}, {"post_status", "publish"}}, + "max(ID) ID", + ) var id uint64 if len(r) > 0 { id = r[0].Id @@ -122,14 +130,16 @@ func GetMaxPostId(a ...any) (uint64, error) { func RecentPosts(a ...any) (r []models.Posts, err error) { ctx := a[0].(context.Context) - r, err = model.Find[models.Posts](ctx, model.SqlBuilder{{ - "post_type", "post", - }, {"post_status", "publish"}}, "ID,post_title,post_password", "", model.SqlBuilder{{"post_date", "desc"}}, nil, nil, 10) - for i, post := range r { - if post.PostPassword != "" { - PasswordProjectTitle(&r[i]) - } - } + num := a[1].(int) + r, err = model.Finds[models.Posts](ctx, model.Conditions( + model.Where(model.SqlBuilder{ + {"post_type", "post"}, + {"post_status", "publish"}, + }), + model.Fields("ID,post_title,post_password"), + model.Order(model.SqlBuilder{{"post_date", "desc"}}), + model.Limit(num), + )) return } @@ -169,19 +179,15 @@ func MonthPost(args ...any) (r []uint64, err error) { ctx := args[0].(context.Context) year, month := args[1].(string), args[2].(string) where := model.SqlBuilder{ - {"post_type", "in", ""}, - {"post_status", "in", ""}, + {"post_type", "post"}, + {"post_status", "publish"}, {"year(post_date)", year}, {"month(post_date)", month}, } - postType := []any{"post"} - status := []any{"publish"} - ids, err := model.Find[models.Posts](ctx, where, "ID", "", model.SqlBuilder{{"Id", "asc"}}, nil, nil, 0, postType, status) - if err != nil { - return - } - for _, post := range ids { - r = append(r, post.Id) - } - return + return model.Column[models.Posts, uint64](ctx, func(v models.Posts) (uint64, bool) { + return v.Id, true + }, model.Conditions( + model.Fields("ID"), + model.Where(where), + )) } diff --git a/internal/pkg/db/db.go b/internal/pkg/db/db.go index a7196c9..3907203 100644 --- a/internal/pkg/db/db.go +++ b/internal/pkg/db/db.go @@ -1,78 +1,32 @@ package db import ( - "context" - "fmt" "github.com/fthvgb1/wp-go/internal/pkg/config" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" - "log" - "os" - "strconv" - "strings" ) -var Db *sqlx.DB +var db *sqlx.DB -type SqlxDb struct { - sqlx *sqlx.DB -} - -func NewSqlxDb(sqlx *sqlx.DB) *SqlxDb { - return &SqlxDb{sqlx: sqlx} -} - -func (r SqlxDb) Select(ctx context.Context, dest any, sql string, params ...any) error { - if os.Getenv("SHOW_SQL") == "true" { - go log.Println(formatSql(sql, params)) - } - return r.sqlx.Select(dest, sql, params...) -} - -func (r SqlxDb) Get(ctx context.Context, dest any, sql string, params ...any) error { - if os.Getenv("SHOW_SQL") == "true" { - go log.Println(formatSql(sql, params)) - } - return r.sqlx.Get(dest, sql, params...) -} - -func formatSql(sql string, params []any) string { - for _, param := range params { - switch param.(type) { - case string: - sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1) - case int64: - sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1) - case int: - sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1) - case uint64: - sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1) - case float64: - sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1) - } - } - return sql -} - -func InitDb() error { - c := config.Conf.Load() +func InitDb() (*sqlx.DB, error) { + c := config.GetConfig() dsn := c.Mysql.Dsn.GetDsn() var err error - Db, err = sqlx.Open("mysql", dsn) + db, err = sqlx.Open("mysql", dsn) if err != nil { - return err + return nil, err } if c.Mysql.Pool.ConnMaxIdleTime != 0 { - Db.SetConnMaxIdleTime(c.Mysql.Pool.ConnMaxLifetime) + db.SetConnMaxIdleTime(c.Mysql.Pool.ConnMaxLifetime) } if c.Mysql.Pool.MaxIdleConn != 0 { - Db.SetMaxIdleConns(c.Mysql.Pool.MaxIdleConn) + db.SetMaxIdleConns(c.Mysql.Pool.MaxIdleConn) } if c.Mysql.Pool.MaxOpenConn != 0 { - Db.SetMaxOpenConns(c.Mysql.Pool.MaxOpenConn) + db.SetMaxOpenConns(c.Mysql.Pool.MaxOpenConn) } if c.Mysql.Pool.ConnMaxLifetime != 0 { - Db.SetConnMaxLifetime(c.Mysql.Pool.ConnMaxLifetime) + db.SetConnMaxLifetime(c.Mysql.Pool.ConnMaxLifetime) } - return err + return db, err } diff --git a/internal/pkg/models/wp_posts.go b/internal/pkg/models/wp_posts.go index df3df27..10cc88a 100644 --- a/internal/pkg/models/wp_posts.go +++ b/internal/pkg/models/wp_posts.go @@ -3,6 +3,7 @@ package models import "time" type Posts struct { + post Id uint64 `gorm:"column:ID" db:"ID" json:"ID" form:"ID"` PostAuthor uint64 `gorm:"column:post_author" db:"post_author" json:"post_author" form:"post_author"` PostDate time.Time `gorm:"column:post_date" db:"post_date" json:"post_date" form:"post_date"` @@ -47,23 +48,19 @@ type PostThumbnail struct { OriginAttachmentData WpAttachmentMetadata } -func (w Posts) PrimaryKey() string { +type post struct { +} + +func (w post) PrimaryKey() string { return "ID" } -func (w Posts) Table() string { - return "wp_posts" -} - -func (w PostArchive) PrimaryKey() string { - return "ID" -} - -func (w PostArchive) Table() string { +func (w post) Table() string { return "wp_posts" } type PostArchive struct { + post Year string `db:"year"` Month string `db:"month"` Posts int `db:"posts"` diff --git a/internal/plugins/digest.go b/internal/plugins/digest.go index 90311ce..d2d97fd 100644 --- a/internal/plugins/digest.go +++ b/internal/plugins/digest.go @@ -17,7 +17,7 @@ var ctx context.Context func InitDigestCache() { ctx = context.Background() - digestCache = cache.NewMemoryMapCacheByFn[uint64](digestRaw, config.Conf.Load().DigestCacheTime) + digestCache = cache.NewMemoryMapCacheByFn[uint64](digestRaw, config.GetConfig().CacheTime.DigestCacheTime) } func ClearDigestCache() { @@ -30,7 +30,7 @@ func FlushCache() { func digestRaw(arg ...any) (string, error) { str := arg[0].(string) id := arg[1].(uint64) - limit := config.Conf.Load().DigestWordCount + limit := config.GetConfig().DigestWordCount if limit < 0 { return str, nil } else if limit == 0 { diff --git a/internal/plugins/posts.go b/internal/plugins/posts.go index fde5ea0..711728d 100644 --- a/internal/plugins/posts.go +++ b/internal/plugins/posts.go @@ -19,18 +19,14 @@ func ApplyPlugin(p *Plugin[models.Posts], post *models.Posts) { } func PasswordProjectTitle(post *models.Posts) { - if post.PostPassword != "" { - post.PostTitle = fmt.Sprintf("密码保护:%s", post.PostTitle) - } + post.PostTitle = fmt.Sprintf("密码保护:%s", post.PostTitle) } func PasswdProjectContent(post *models.Posts) { - if post.PostContent != "" { - format := ` + format := `

此内容受密码保护。如需查阅,请在下列字段中输入您的密码。

` - post.PostContent = fmt.Sprintf(format, post.Id, post.Id) - } + post.PostContent = fmt.Sprintf(format, post.Id, post.Id) } diff --git a/internal/theme/theme.go b/internal/theme/theme.go index bc0e23a..f38d3fb 100644 --- a/internal/theme/theme.go +++ b/internal/theme/theme.go @@ -11,7 +11,7 @@ func InitThemeAndTemplateFuncMap() { } func GetTemplateName() string { - tmlp := config.Conf.Load().Theme + tmlp := config.GetConfig().Theme if tmlp == "" { tmlp = wpconfig.Options.Value("template") } diff --git a/internal/theme/twentyseventeen/twentyseventeen.go b/internal/theme/twentyseventeen/twentyseventeen.go index 430dea9..7f3df32 100644 --- a/internal/theme/twentyseventeen/twentyseventeen.go +++ b/internal/theme/twentyseventeen/twentyseventeen.go @@ -139,23 +139,16 @@ func (h handle) bodyClass() string { s := "" switch h.scene { case plugins.Search: + s = "search-no-results" if len(h.ginH["posts"].([]models.Posts)) > 0 { s = "search-results" - } else { - s = "search-no-results" } - case plugins.Category: + case plugins.Category, plugins.Tag: cat := h.c.Param("category") - _, cate := slice.SearchFirst(cache.Categories(h.c), func(my models.TermsMy) bool { - return my.Name == cat - }) - if cate.Slug[0] != '%' { - s = cate.Slug + if cat == "" { + cat = h.c.Param("tag") } - s = fmt.Sprintf("category-%d %v", cate.Terms.TermId, s) - case plugins.Tag: - cat := h.c.Param("tag") - _, cate := slice.SearchFirst(cache.Tags(h.c), func(my models.TermsMy) bool { + _, cate := slice.SearchFirst(cache.CategoriesTags(h.c, h.scene), func(my models.TermsMy) bool { return my.Name == cat }) if cate.Slug[0] != '%' { diff --git a/model/query.go b/model/query.go index 2f0278e..2373666 100644 --- a/model/query.go +++ b/model/query.go @@ -8,7 +8,7 @@ import ( "strings" ) -func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { +func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { var rr T var w string var args []any @@ -55,11 +55,11 @@ func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, gr if group == "" { tpx := "select count(*) n from %s %s %s limit 1" sq := fmt.Sprintf(tpx, rr.Table(), j, w) - err = globalBb.Get(ctx, &n, sq, args...) + err = db.Get(ctx, &n, sq, args...) } else { tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s" sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int())) - err = globalBb.Get(ctx, &n, sq, args...) + err = db.Get(ctx, &n, sq, args...) } if err != nil { @@ -78,13 +78,18 @@ func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, gr } tp := "select %s from %s %s %s %s %s %s limit %d,%d" sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) - err = globalBb.Select(ctx, &r, sq, args...) + err = db.Select(ctx, &r, sq, args...) if err != nil { return } return } +func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { + r, total, err = pagination[T](globalBb, ctx, where, fields, group, page, pageSize, order, join, having, in...) + return +} + func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) { var r T sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) diff --git a/model/query_test.go b/model/query_test.go index 8e335a1..1aafe69 100644 --- a/model/query_test.go +++ b/model/query_test.go @@ -3,15 +3,11 @@ package model import ( "context" "database/sql" - "fmt" "github.com/fthvgb1/wp-go/helper/number" "github.com/fthvgb1/wp-go/helper/slice" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" - "log" "reflect" - "strconv" - "strings" "testing" "time" ) @@ -102,40 +98,6 @@ func (p post) Table() string { return "wp_posts" } -type SqlxDb struct { - sqlx *sqlx.DB -} - -var Db *SqlxDb - -func (r SqlxDb) Select(_ context.Context, dest any, sql string, params ...any) error { - log.Println(formatSql(sql, params)) - return r.sqlx.Select(dest, sql, params...) -} - -func (r SqlxDb) Get(_ context.Context, dest any, sql string, params ...any) error { - log.Println(formatSql(sql, params)) - return r.sqlx.Get(dest, sql, params...) -} - -func formatSql(sql string, params []any) string { - for _, param := range params { - switch param.(type) { - case string: - sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1) - case int64: - sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1) - case int: - sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1) - case uint64: - sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1) - case float64: - sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1) - } - } - return sql -} - var ctx = context.Background() func init() { @@ -143,8 +105,7 @@ func init() { if err != nil { panic(err) } - Db = &SqlxDb{db} - InitDB(Db) + InitDB(NewSqlxQuery(db)) } func TestFind(t *testing.T) { type args struct { diff --git a/model/querycondition.go b/model/querycondition.go index 6b720a3..1f611ea 100644 --- a/model/querycondition.go +++ b/model/querycondition.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "github.com/fthvgb1/wp-go/helper/slice" "strings" ) @@ -11,6 +12,19 @@ import ( // // Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { + r, err = finds[T](globalBb, ctx, q) + return +} + +// DBFind 同 Finds 使用指定 db 查询 +// +// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 +func DBFind[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { + r, err = finds[T](db, ctx, q) + return +} + +func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { var rr T w := "" var args []any @@ -48,25 +62,22 @@ func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { l = fmt.Sprintf(" %s offset %d", l, q.offset) } sq := fmt.Sprintf(tp, q.fields, rr.Table(), j, w, groupBy, h, q.order.parseOrderBy(), l) - err = globalBb.Select(ctx, &r, sq, args...) + err = db.Select(ctx, &r, sq, args...) return } -// ChunkFind 分片查询并直接返回所有结果 -// -// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { +func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { i := 1 var rr []T var total int var offset int for { if 1 == i { - rr, total, err = SimplePagination[T](ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) + rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) } else { q.offset = offset q.limit = perLimit - rr, err = Finds[T](ctx, q) + rr, err = finds[T](db, ctx, q) } offset += perLimit if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { @@ -81,10 +92,39 @@ func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r return } +// ChunkFind 分片查询并直接返回所有结果 +// +// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 +func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { + r, err = chunkFind[T](globalBb, ctx, perLimit, q) + return +} + +// DBChunkFind 同 ChunkFind +// +// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 +func DBChunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { + r, err = chunkFind[T](db, ctx, perLimit, q) + return +} + // Chunk 分片查询并函数过虑返回新类型的切片 // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { + r, err = chunk(globalBb, ctx, perLimit, fn, q) + return +} + +// DBChunk 同 Chunk +// +// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 +func DBChunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { + r, err = chunk(db, ctx, perLimit, fn, q) + return +} + +func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { i := 1 var rr []T var count int @@ -92,11 +132,11 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R var offset int for { if 1 == i { - rr, total, err = SimplePagination[T](ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) + rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) } else { q.offset = offset q.limit = perLimit - rr, err = Finds[T](ctx, q) + rr, err = finds[T](db, ctx, q) } offset += perLimit if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { @@ -123,3 +163,26 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R func Pagination[T Model](ctx context.Context, q *QueryCondition) ([]T, int, error) { return SimplePagination[T](ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) } + +// DBPagination 同 Pagination 方便多个db使用 +// +// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数 +func DBPagination[T Model](db dbQuery, ctx context.Context, q *QueryCondition) ([]T, int, error) { + return pagination[T](db, ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) +} + +func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q *QueryCondition) ([]T, error) { + return column[V, T](globalBb, ctx, fn, q) +} +func DBColumn[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { + return column[V, T](db, ctx, fn, q) +} + +func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { + res, err := finds[V](db, ctx, q) + if err != nil { + return nil, err + } + r = slice.FilterAndMap(res, fn) + return +} diff --git a/model/querycondition_test.go b/model/querycondition_test.go index 9829844..1bc7c9b 100644 --- a/model/querycondition_test.go +++ b/model/querycondition_test.go @@ -230,3 +230,46 @@ func TestPagination(t *testing.T) { }) } } + +func TestColumn(t *testing.T) { + type args[V Model, T any] struct { + ctx context.Context + fn func(V) (T, bool) + q *QueryCondition + } + type testCase[V Model, T any] struct { + name string + args args[V, T] + wantR []T + wantErr bool + } + tests := []testCase[post, uint64]{ + { + name: "t1", + args: args[post, uint64]{ + ctx: ctx, + fn: func(t post) (uint64, bool) { + return t.Id, true + }, + q: Conditions( + Where(SqlBuilder{ + {"ID", "<", "200", "int"}, + }), + ), + }, + wantR: []uint64{63, 64, 190, 193}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotR, err := Column[post](tt.args.ctx, tt.args.fn, tt.args.q) + if (err != nil) != tt.wantErr { + t.Errorf("Column() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotR, tt.wantR) { + t.Errorf("Column() gotR = %v, want %v", gotR, tt.wantR) + } + }) + } +} diff --git a/model/sqxquery.go b/model/sqxquery.go new file mode 100644 index 0000000..9b7ac29 --- /dev/null +++ b/model/sqxquery.go @@ -0,0 +1,51 @@ +package model + +import ( + "context" + "fmt" + "github.com/jmoiron/sqlx" + "log" + "os" + "strconv" + "strings" +) + +type SqlxQuery struct { + sqlx *sqlx.DB +} + +func NewSqlxQuery(sqlx *sqlx.DB) SqlxQuery { + return SqlxQuery{sqlx: sqlx} +} + +func (r SqlxQuery) Select(ctx context.Context, dest any, sql string, params ...any) error { + if os.Getenv("SHOW_SQL") == "true" { + go log.Println(formatSql(sql, params)) + } + return r.sqlx.Select(dest, sql, params...) +} + +func (r SqlxQuery) Get(ctx context.Context, dest any, sql string, params ...any) error { + if os.Getenv("SHOW_SQL") == "true" { + go log.Println(formatSql(sql, params)) + } + return r.sqlx.Get(dest, sql, params...) +} + +func formatSql(sql string, params []any) string { + for _, param := range params { + switch param.(type) { + case string: + sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1) + case int64: + sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1) + case int: + sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1) + case uint64: + sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1) + case float64: + sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1) + } + } + return sql +}