From c759dfcdf6214c35cb21dad2a1f8c913f3c297f7 Mon Sep 17 00:00:00 2001 From: xing Date: Fri, 4 Nov 2022 17:42:11 +0800 Subject: [PATCH] having --- actions/common/comments.go | 6 ++-- actions/common/common.go | 4 +-- actions/common/posts.go | 10 +++--- actions/common/users.go | 2 +- models/globalInit.go | 4 +-- models/model.go | 70 ++++++++++++++++++++++++++++---------- 6 files changed, 65 insertions(+), 31 deletions(-) diff --git a/actions/common/comments.go b/actions/common/comments.go index f99a84d..6d27be7 100644 --- a/actions/common/comments.go +++ b/actions/common/comments.go @@ -23,7 +23,7 @@ func recentComments(...any) (r []models.WpComments, err error) { {"post_status", "publish"}, }, "comment_ID,comment_author,comment_post_ID,post_title", "", models.SqlBuilder{{"comment_date_gmt", "desc"}}, models.SqlBuilder{ {"a", "left join", "wp_posts b", "a.comment_post_ID=b.ID"}, - }, 10) + }, nil, 10) } func PostComments(ctx context.Context, Id uint64) ([]models.WpComments, error) { @@ -42,7 +42,7 @@ func postComments(args ...any) ([]uint64, error) { }, "comment_ID", "", models.SqlBuilder{ {"comment_date_gmt", "asc"}, {"comment_ID", "asc"}, - }, nil, 0) + }, nil, nil, 0) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func getCommentByIds(args ...any) (map[uint64]models.WpComments, error) { m := make(map[uint64]models.WpComments) r, err := models.Find[models.WpComments](models.SqlBuilder{ {"comment_ID", "in", ""}, {"comment_approved", "1"}, - }, "*", "", nil, nil, 0, helper.SliceMap(ids, helper.ToAny[uint64])) + }, "*", "", nil, nil, nil, 0, helper.SliceMap(ids, helper.ToAny[uint64])) if err != nil { return m, err } diff --git a/actions/common/common.go b/actions/common/common.go index 1fbdccb..c374556 100644 --- a/actions/common/common.go +++ b/actions/common/common.go @@ -105,7 +105,7 @@ type PostContext struct { func archives() ([]models.PostArchive, error) { return models.Find[models.PostArchive](models.SqlBuilder{ {"post_type", "post"}, {"post_status", "publish"}, - }, "YEAR(post_date) AS `year`, MONTH(post_date) AS `month`, count(ID) as posts", "year,month", models.SqlBuilder{{"year", "desc"}, {"month", "desc"}}, nil, 0) + }, "YEAR(post_date) AS `year`, MONTH(post_date) AS `month`, count(ID) as posts", "year,month", models.SqlBuilder{{"year", "desc"}, {"month", "desc"}}, nil, nil, 0) } func Archives() (r []models.PostArchive) { @@ -127,7 +127,7 @@ func categories(...any) (terms []models.WpTermsMy, err error) { {"t.name", "asc"}, }, models.SqlBuilder{ {"t", "inner join", "wp_term_taxonomy tt", "t.term_id = tt.term_id"}, - }, 0, in) + }, nil, 0, in) for i := 0; i < len(terms); i++ { if v, ok := models.Terms[terms[i].WpTerms.TermId]; ok { terms[i].WpTerms = v diff --git a/actions/common/posts.go b/actions/common/posts.go index 5856996..85e3a6d 100644 --- a/actions/common/posts.go +++ b/actions/common/posts.go @@ -42,7 +42,7 @@ func getPostsByIds(ids ...any) (m map[uint64]models.WpPosts, err error) { "left join", "wp_term_taxonomy c", "b.term_taxonomy_id=c.term_taxonomy_id", }, { "left join", "wp_terms d", "c.term_id=d.term_id", - }}, 0, arg) + }}, nil, 0, arg) if err != nil { return m, err } @@ -97,7 +97,7 @@ func searchPostIds(args ...any) (ids PostIds, err error) { join := args[4].(models.SqlBuilder) postType := args[5].([]any) postStatus := args[6].([]any) - res, total, err := models.SimplePagination[models.WpPosts](where, "ID", "", page, limit, order, join, postType, postStatus) + res, total, err := models.SimplePagination[models.WpPosts](where, "ID", "", page, limit, order, join, nil, postType, postStatus) for _, posts := range res { ids.Ids = append(ids.Ids, posts.Id) } @@ -109,7 +109,7 @@ func searchPostIds(args ...any) (ids PostIds, err error) { } func getMaxPostId(...any) ([]uint64, error) { - r, err := models.Find[models.WpPosts](models.SqlBuilder{{"post_type", "post"}, {"post_status", "publish"}}, "max(ID) ID", "", nil, nil, 0) + r, err := models.Find[models.WpPosts](models.SqlBuilder{{"post_type", "post"}, {"post_status", "publish"}}, "max(ID) ID", "", nil, nil, nil, 0) var id uint64 if len(r) > 0 { id = r[0].Id @@ -133,7 +133,7 @@ func RecentPosts(ctx context.Context, n int) (r []models.WpPosts) { func recentPosts(...any) (r []models.WpPosts, err error) { r, err = models.Find[models.WpPosts](models.SqlBuilder{{ "post_type", "post", - }, {"post_status", "publish"}}, "ID,post_title,post_password", "", models.SqlBuilder{{"post_date", "desc"}}, nil, 10) + }, {"post_status", "publish"}}, "ID,post_title,post_password", "", models.SqlBuilder{{"post_date", "desc"}}, nil, nil, 10) for i, post := range r { if post.PostPassword != "" { PasswordProjectTitle(&r[i]) @@ -207,7 +207,7 @@ func monthPost(args ...any) (r []uint64, err error) { } postType := []any{"post"} status := []any{"publish"} - ids, err := models.Find[models.WpPosts](where, "ID", "", models.SqlBuilder{{"Id", "asc"}}, nil, 0, postType, status) + ids, err := models.Find[models.WpPosts](where, "ID", "", models.SqlBuilder{{"Id", "asc"}}, nil, nil, 0, postType, status) if err != nil { return } diff --git a/actions/common/users.go b/actions/common/users.go index a07dc6a..288942c 100644 --- a/actions/common/users.go +++ b/actions/common/users.go @@ -9,7 +9,7 @@ import ( func getUsers(...any) (m map[uint64]models.WpUsers, err error) { m = make(map[uint64]models.WpUsers) - r, err := models.Find[models.WpUsers](nil, "*", "", nil, nil, 0) + r, err := models.Find[models.WpUsers](nil, "*", "", nil, nil, nil, 0) for _, user := range r { m[user.Id] = user } diff --git a/models/globalInit.go b/models/globalInit.go index 2be76e9..31c90d9 100644 --- a/models/globalInit.go +++ b/models/globalInit.go @@ -22,14 +22,14 @@ func InitOptions() error { } func InitTerms() (err error) { - terms, err := Find[WpTerms](nil, "*", "", nil, nil, 0) + terms, err := Find[WpTerms](nil, "*", "", nil, nil, nil, 0) if err != nil { return err } for _, wpTerms := range terms { Terms[wpTerms.TermId] = wpTerms } - termTax, err := Find[WpTermTaxonomy](nil, "*", "", nil, nil, 0) + termTax, err := Find[WpTermTaxonomy](nil, "*", "", nil, nil, nil, 0) if err != nil { return err } diff --git a/models/model.go b/models/model.go index 6e47cd6..5bd783e 100644 --- a/models/model.go +++ b/models/model.go @@ -18,7 +18,7 @@ type Model interface { } type ParseWhere interface { - ParseWhere(...[]any) (string, []any, error) + ParseWhere(*[][]any) (string, []any, error) } type SqlBuilder [][]string @@ -40,10 +40,10 @@ func (w SqlBuilder) parseField(ss []string, s *strings.Builder) { } } -func (w SqlBuilder) parseIn(ss []string, s *strings.Builder, c *int, args *[]any, in [][]any) (t bool) { - if helper.IsContainInArr(ss[1], []string{"in", "not in"}) && len(in) > 0 { +func (w SqlBuilder) parseIn(ss []string, s *strings.Builder, c *int, args *[]any, in *[][]any) (t bool) { + if helper.IsContainInArr(ss[1], []string{"in", "not in"}) && len(*in) > 0 { s.WriteString(" (") - for _, p := range in[*c] { + for _, p := range (*in)[*c] { s.WriteString("?,") *args = append(*args, p) } @@ -76,7 +76,7 @@ func (w SqlBuilder) parseType(ss []string, args *[]any) error { return nil } -func (w SqlBuilder) ParseWhere(in ...[]any) (string, []any, error) { +func (w SqlBuilder) ParseWhere(in *[][]any) (string, []any, error) { var s strings.Builder args := make([]any, 0, len(w)) c := 0 @@ -138,6 +138,9 @@ func (w SqlBuilder) ParseWhere(in ...[]any) (string, []any, error) { s.WriteString(ss) ss = s.String() } + if len(*in) > c { + *in = (*in)[c:] + } return ss, args, nil } @@ -178,12 +181,22 @@ func (w SqlBuilder) parseJoin() string { return s.String() } -func SimplePagination[T Model](where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, in ...[]any) (r []T, total int, err error) { +func SimplePagination[T Model](where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { var rr T - w, args, err := where.ParseWhere(in...) + w, args, err := where.ParseWhere(&in) if err != nil { return r, total, err } + h := "" + if having != nil { + hh, arg, err := having.ParseWhere(&in) + if err != nil { + return r, total, err + } + args = append(args, arg...) + h = strings.Replace(hh, " where", " having", 1) + } + n := struct { N int `db:"n" json:"n"` }{} @@ -194,15 +207,27 @@ func SimplePagination[T Model](where ParseWhere, fields, group string, page, pag g.WriteString(group) groupBy = g.String() } + if having != nil { + tm := map[string]struct{}{} + for _, s := range strings.Split(group, ",") { + tm[s] = struct{}{} + } + for _, ss := range having { + if _, ok := tm[ss[0]]; !ok { + group = fmt.Sprintf("%s,%s", group, ss[0]) + } + } + group = strings.Trim(group, ",") + } j := join.parseJoin() if group == "" { tpx := "select count(*) n from %s %s %s limit 1" sq := fmt.Sprintf(tpx, rr.Table(), j, w) err = db.Db.Get(&n, sq, args...) } else { - tpx := "select count(*) n from (select %s from %s %s %s %s ) %s" + tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s" rand.Seed(int64(time.Now().Nanosecond())) - sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, fmt.Sprintf("table%d", rand.Int())) + sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int())) err = db.Db.Get(&n, sq, args...) } @@ -220,8 +245,8 @@ func SimplePagination[T Model](where ParseWhere, fields, group string, page, pag if offset >= total { return } - tp := "select %s from %s %s %s %s %s limit %d,%d" - sql := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, order.parseOrderBy(), offset, pageSize) + tp := "select %s from %s %s %s %s %s %s limit %d,%d" + sql := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) err = db.Db.Select(&r, sql, args...) if err != nil { return @@ -241,7 +266,7 @@ func FindOneById[T Model, I helper.IntNumber](id I) (T, error) { func FirstOne[T Model](where ParseWhere, fields string, order SqlBuilder, in ...[]any) (T, error) { var r T - w, args, err := where.ParseWhere(in...) + w, args, err := where.ParseWhere(&in) if err != nil { return r, err } @@ -256,7 +281,7 @@ func FirstOne[T Model](where ParseWhere, fields string, order SqlBuilder, in ... func LastOne[T Model](where ParseWhere, fields string, in ...[]any) (T, error) { var r T - w, args, err := where.ParseWhere(in...) + w, args, err := where.ParseWhere(&in) if err != nil { return r, err } @@ -272,7 +297,7 @@ func LastOne[T Model](where ParseWhere, fields string, in ...[]any) (T, error) { func SimpleFind[T Model](where ParseWhere, fields string, in ...[]any) ([]T, error) { var r []T var rr T - w, args, err := where.ParseWhere(in...) + w, args, err := where.ParseWhere(&in) if err != nil { return r, err } @@ -296,16 +321,25 @@ func Select[T Model](sql string, params ...any) ([]T, error) { return r, nil } -func Find[T Model](where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, limit int, in ...[]any) (r []T, err error) { +func Find[T Model](where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) { var rr T w := "" var args []any if where != nil { - w, args, err = where.ParseWhere(in...) + w, args, err = where.ParseWhere(&in) if err != nil { return r, err } } + h := "" + if having != nil { + hh, arg, err := having.ParseWhere(&in) + if err != nil { + return r, err + } + args = append(args, arg...) + h = strings.Replace(hh, " where", " having", 1) + } j := join.parseJoin() groupBy := "" @@ -315,12 +349,12 @@ func Find[T Model](where ParseWhere, fields, group string, order SqlBuilder, joi g.WriteString(group) groupBy = g.String() } - tp := "select %s from %s %s %s %s %s %s" + tp := "select %s from %s %s %s %s %s %s %s" l := "" if limit > 0 { l = fmt.Sprintf(" limit %d", limit) } - sql := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, order.parseOrderBy(), l) + sql := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), l) err = db.Db.Select(&r, sql, args...) return }