From e0786f7f8b6a596dcecdfedcbc07db160ff32a7f Mon Sep 17 00:00:00 2001 From: xing Date: Sat, 25 Feb 2023 23:10:42 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2=E7=9B=B8=E5=85=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/pkg/dao/posts.go | 15 +--- internal/theme/common/index.go | 14 +++- model/condition.go | 54 ++++++------- model/parse.go | 8 +- model/query.go | 136 +++++++++++++-------------------- model/query_test.go | 69 ++++++++--------- model/querycondition.go | 117 ++++++++++++++-------------- model/querycondition_test.go | 22 +++--- 8 files changed, 202 insertions(+), 233 deletions(-) diff --git a/internal/pkg/dao/posts.go b/internal/pkg/dao/posts.go index 310eabe..47ac1b1 100644 --- a/internal/pkg/dao/posts.go +++ b/internal/pkg/dao/posts.go @@ -91,18 +91,9 @@ func GetPostsByIds(a ...any) (m map[uint64]models.Posts, err error) { func SearchPostIds(args ...any) (ids PostIds, err error) { ctx := args[0].(context.Context) - where := args[1].(model.SqlBuilder) - page := args[2].(int) - limit := args[3].(int) - order := args[4].(model.SqlBuilder) - join := args[5].(model.SqlBuilder) - postType := args[6].([]any) - postStatus := args[7].([]any) - res, total, err := model.SimplePagination[models.Posts]( - ctx, where, "ID", - "", page, limit, order, - join, nil, postType, postStatus, - ) + q := args[1].(model.QueryCondition) + q.Fields = "ID" + res, total, err := model.Pagination[models.Posts](ctx, q) for _, posts := range res { ids.Ids = append(ids.Ids, posts.Id) } diff --git a/internal/theme/common/index.go b/internal/theme/common/index.go index 1c03d2e..2730f84 100644 --- a/internal/theme/common/index.go +++ b/internal/theme/common/index.go @@ -56,16 +56,22 @@ func (i *IndexHandle) ParseIndex(parm *IndexParams) (err error) { func (i *IndexHandle) GetIndexData() (posts []models.Posts, totalRaw int, err error) { + q := model.QueryCondition{ + Where: i.Param.Where, + Page: i.Param.Page, + Limit: i.Param.PageSize, + Order: model.SqlBuilder{{i.Param.OrderBy, i.Param.Order}}, + Join: i.Param.Join, + In: [][]any{i.Param.PostType, i.Param.PostStatus}, + } switch i.Scene { case constraints.Home, constraints.Category, constraints.Tag, constraints.Author: - posts, totalRaw, err = cache.PostLists(i.C, i.Param.CacheKey, i.C, i.Param.Where, i.Param.Page, i.Param.PageSize, - model.SqlBuilder{{i.Param.OrderBy, i.Param.Order}}, i.Param.Join, i.Param.PostType, i.Param.PostStatus) + posts, totalRaw, err = cache.PostLists(i.C, i.Param.CacheKey, i.C, q) case constraints.Search: - posts, totalRaw, err = cache.SearchPost(i.C, i.Param.CacheKey, i.C, i.Param.Where, i.Param.Page, i.Param.PageSize, - model.SqlBuilder{{i.Param.OrderBy, i.Param.Order}}, i.Param.Join, i.Param.PostType, i.Param.PostStatus) + posts, totalRaw, err = cache.SearchPost(i.C, i.Param.CacheKey, i.C, q) case constraints.Archive: diff --git a/model/condition.go b/model/condition.go index 7da10d8..f142448 100644 --- a/model/condition.go +++ b/model/condition.go @@ -1,26 +1,26 @@ package model type QueryCondition struct { - where ParseWhere - from string - fields string - group string - order SqlBuilder - join SqlBuilder - having SqlBuilder - page int - limit int - offset int - in [][]any + Where ParseWhere + From string + Fields string + Group string + Order SqlBuilder + Join SqlBuilder + Having SqlBuilder + Page int + Limit int + Offset int + In [][]any } -func Conditions(fns ...Condition) *QueryCondition { - r := &QueryCondition{} +func Conditions(fns ...Condition) QueryCondition { + r := QueryCondition{} for _, fn := range fns { - fn(r) + fn(&r) } - if r.fields == "" { - r.fields = "*" + if r.Fields == "" { + r.Fields = "*" } return r } @@ -29,65 +29,65 @@ type Condition func(c *QueryCondition) func Where(where ParseWhere) Condition { return func(c *QueryCondition) { - c.where = where + c.Where = where } } func Fields(fields string) Condition { return func(c *QueryCondition) { - c.fields = fields + c.Fields = fields } } func From(from string) Condition { return func(c *QueryCondition) { - c.from = from + c.From = from } } func Group(group string) Condition { return func(c *QueryCondition) { - c.group = group + c.Group = group } } func Order(order SqlBuilder) Condition { return func(c *QueryCondition) { - c.order = order + c.Order = order } } func Join(join SqlBuilder) Condition { return func(c *QueryCondition) { - c.join = join + c.Join = join } } func Having(having SqlBuilder) Condition { return func(c *QueryCondition) { - c.having = having + c.Having = having } } func Page(page int) Condition { return func(c *QueryCondition) { - c.page = page + c.Page = page } } func Limit(limit int) Condition { return func(c *QueryCondition) { - c.limit = limit + c.Limit = limit } } func Offset(offset int) Condition { return func(c *QueryCondition) { - c.offset = offset + c.Offset = offset } } func In(in ...[]any) Condition { return func(c *QueryCondition) { - c.in = append(c.in, in...) + c.In = append(c.In, in...) } } diff --git a/model/parse.go b/model/parse.go index 3e690d9..3d9a19f 100644 --- a/model/parse.go +++ b/model/parse.go @@ -56,9 +56,11 @@ func (w SqlBuilder) parseType(ss []string, args *[]any) error { // ParseWhere 解析为where条件,支持3种风格,具体用法参照query_test中的 Find 的测试方法 // -// 1. 2个为一组 {{"field1","value1"},{"field2","value2"}} => where field1='value1' and field2='value2' +// 1. 1个为一组 {{"field operator value"}} 为纯字符串条件,不对参数做处理 // -// 2. 3个或4个为一组 {{"field","operator","value"[,"int|float"]}} => where field operator 'string'|int|float +// 2. 2个为一组 {{"field1","value1"},{"field2","value2"}} => where field1='value1' and field2='value2' +// +// 3. 3个或4个为一组 {{"field","operator","value"[,"int|float"]}} => where field operator 'string'|int|float // // {{"a",">","1","int"}} => where 'a'> 1 // @@ -66,7 +68,7 @@ func (w SqlBuilder) parseType(ss []string, args *[]any) error { // // 另外如果是操作符为in的话为 {{"field","in",""}} => where field in (?,..) in的条件传给 in参数 // -// 3. 5的倍数为一组{{"and|or","field","operator","value","int|float"}}会忽然掉第一组的and|or +// 4. 5的倍数为一组{{"and|or","field","operator","value","int|float"}}会忽然掉第一组的and|or // // {{"and","field","=","value1","","and","field","=","value2",""}} => where (field = 'value1' and field = 'value2') // diff --git a/model/query.go b/model/query.go index 5a45a36..8d7da6b 100644 --- a/model/query.go +++ b/model/query.go @@ -3,81 +3,54 @@ package model import ( "context" "fmt" + "github.com/fthvgb1/wp-go/helper/number" + str "github.com/fthvgb1/wp-go/helper/strings" "golang.org/x/exp/constraints" "math/rand" "strings" ) -func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { - var rr T - var w string - var args []any - if where != nil { - w, args, err = where.ParseWhere(&in) - if err != nil { - return r, total, err +func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, total int, err error) { + qx := QueryCondition{ + Where: q.Where, + Having: q.Having, + Join: q.Join, + In: q.In, + Group: q.Group, + From: q.From, + } + if q.Group != "" { + qx.Fields = q.Fields + sq, in, er := BuildQuerySql[T](qx) + qx.In = [][]any{in} + if er != nil { + err = er + return + } + qx.From = str.Join("( ", sq, " ) ", "table", number.ToString(rand.Int())) + qx = QueryCondition{ + From: qx.From, + In: qx.In, } } - h := "" - if having != nil { - hh, arg, err := having.ParseWhere(&in) - if err != nil { - return r, total, err - } - args = append(args, arg...) - h = strings.Replace(hh, " where", " having", 1) - } - - n := struct { - N int `db:"n" json:"n"` - }{} - groupBy := "" - if group != "" { - g := strings.Builder{} - g.WriteString(" group by ") - g.WriteString(group) - groupBy = g.String() - } - if having != nil { - tm := map[string]struct{}{} - for _, s := range strings.Split(group, ",") { - tm[s] = struct{}{} - } - for _, ss := range having { - if _, ok := tm[ss[0]]; !ok { - group = fmt.Sprintf("%s,%s", group, ss[0]) - } - } - group = strings.Trim(group, ",") - } - j := join.parseJoin() - if group == "" { - tpx := "select count(*) n from %s %s %s limit 1" - sq := fmt.Sprintf(tpx, rr.Table(), j, w) - err = db.Get(ctx, &n, sq, args...) - } else { - tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s" - sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int())) - err = db.Get(ctx, &n, sq, args...) - } - - if err != nil { + n, err := GetField[T](ctx, "count(*)", qx) + total = str.ToInt[int](n) + if err != nil || total < 1 { return } - if n.N == 0 { - return - } - total = n.N offset := 0 - if page > 1 { - offset = (page - 1) * pageSize + if q.Page > 1 { + offset = (q.Page - 1) * q.Limit } if offset >= total { return } - tp := "select %s from %s %s %s %s %s %s limit %d,%d" - sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) + q.Offset = offset + sq, args, err := BuildQuerySql[T](q) + if err != nil { + return + } err = db.Select(ctx, &r, sq, args...) if err != nil { return @@ -85,11 +58,6 @@ func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fiel return } -func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { - r, total, err = pagination[T](globalBb, ctx, where, fields, group, page, pageSize, order, join, having, in...) - return -} - func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) { var r T sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) @@ -101,12 +69,12 @@ func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, } func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (r T, err error) { - s, args, err := BuildQuerySql[T](&QueryCondition{ - where: where, - fields: fields, - order: order, - in: in, - limit: 1, + s, args, err := BuildQuerySql[T](QueryCondition{ + Where: where, + Fields: fields, + Order: order, + In: in, + Limit: 1, }) if err != nil { return @@ -136,10 +104,10 @@ func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in . } func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) { - s, args, err := BuildQuerySql[T](&QueryCondition{ - where: where, - fields: fields, - in: in, + s, args, err := BuildQuerySql[T](QueryCondition{ + Where: where, + Fields: fields, + In: in, }) if err != nil { return @@ -161,16 +129,16 @@ func Select[T Model](ctx context.Context, sql string, params ...any) ([]T, error func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) { q := QueryCondition{ - where: where, - fields: fields, - group: group, - order: order, - join: join, - having: having, - limit: limit, - in: in, + Where: where, + Fields: fields, + Group: group, + Order: order, + Join: join, + Having: having, + Limit: limit, + In: in, } - s, args, err := BuildQuerySql[T](&q) + s, args, err := BuildQuerySql[T](q) if err != nil { return } diff --git a/model/query_test.go b/model/query_test.go index ee38a57..a44f726 100644 --- a/model/query_test.go +++ b/model/query_test.go @@ -3,8 +3,7 @@ package model import ( "context" "database/sql" - "github.com/fthvgb1/wp-go/helper/number" - "github.com/fthvgb1/wp-go/helper/slice" + "fmt" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" "reflect" @@ -108,7 +107,15 @@ func init() { if err != nil { panic(err) } - glob = NewSqlxQuery(db, NewUniversalDb(nil, nil)) + glob = NewSqlxQuery(db, NewUniversalDb(func(ctx2 context.Context, a any, s string, a2 ...any) error { + x := FormatSql(s, a2...) + fmt.Println(x) + return glob.Selects(ctx2, a, s, a2...) + }, func(ctx2 context.Context, a any, s string, a2 ...any) error { + x := FormatSql(s, a2...) + fmt.Println(x) + return glob.Gets(ctx2, a, s, a2...) + })) ddb = db InitDB(glob) } @@ -456,65 +463,55 @@ func TestSimpleFind(t *testing.T) { } } -func TestSimplePagination(t *testing.T) { +func Test_pagination(t *testing.T) { type args struct { - where ParseWhere - fields string - group string - page int - pageSize int - order SqlBuilder - join SqlBuilder - having SqlBuilder - in [][]any + db dbQuery + ctx context.Context + q QueryCondition } - tests := []struct { + type testCase[T Model] struct { name string args args - wantR []post + wantR []T wantTotal int wantErr bool - }{ + } + tests := []testCase[post]{ { name: "t1", args: args{ - where: SqlBuilder{ - {"ID", "in", ""}, + db: glob, + ctx: ctx, + q: QueryCondition{ + Fields: "post_type,count(*) ID", + Group: "post_type", + Having: SqlBuilder{{"ID", ">", "1", "int"}}, }, - fields: "*", - group: "", - page: 1, - pageSize: 5, - order: nil, - join: nil, - having: nil, - in: [][]any{slice.ToAnySlice(number.Range(431, 440, 1))}, }, wantR: func() (r []post) { - r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?,?,?,?)", slice.ToAnySlice(number.Range(431, 435, 1))...) - if err != nil && err != sql.ErrNoRows { + + err := glob.Selects(ctx, &r, "select post_type,count(*) ID from wp_posts group by post_type having `ID`> 1") + if err != nil { panic(err) - } else if err == sql.ErrNoRows { - err = nil } - return + return r }(), - wantTotal: 10, + wantTotal: 7, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotR, gotTotal, err := SimplePagination[post](ctx, tt.args.where, tt.args.fields, tt.args.group, tt.args.page, tt.args.pageSize, tt.args.order, tt.args.join, tt.args.having, tt.args.in...) + gotR, gotTotal, err := pagination[post](tt.args.db, tt.args.ctx, tt.args.q) if (err != nil) != tt.wantErr { - t.Errorf("SimplePagination() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("pagination() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("SimplePagination() gotR = %v, want %v", gotR, tt.wantR) + t.Errorf("pagination() gotR = %v, want %v", gotR, tt.wantR) } if gotTotal != tt.wantTotal { - t.Errorf("SimplePagination() gotTotal = %v, want %v", gotTotal, tt.wantTotal) + t.Errorf("pagination() gotTotal = %v, want %v", gotTotal, tt.wantTotal) } }) } diff --git a/model/querycondition.go b/model/querycondition.go index 8cb4a59..ec3c8c9 100644 --- a/model/querycondition.go +++ b/model/querycondition.go @@ -12,7 +12,7 @@ import ( // Finds 比 Find 多一个offset // // Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 -func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { +func Finds[T Model](ctx context.Context, q QueryCondition) (r []T, err error) { r, err = finds[T](globalBb, ctx, q) return } @@ -20,12 +20,12 @@ func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { // FindFromDB 同 Finds 使用指定 db 查询 // // Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 -func FindFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { +func FindFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) { r, err = finds[T](db, ctx, q) return } -func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { +func finds[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) { sq, args, err := BuildQuerySql[T](q) if err != nil { return @@ -34,17 +34,17 @@ func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, return } -func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { +func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) { i := 1 var rr []T var total int var offset int for { if 1 == i { - rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) + rr, total, err = pagination[T](db, ctx, q) } else { - q.offset = offset - q.limit = perLimit + q.Offset = offset + q.Limit = perLimit rr, err = finds[T](db, ctx, q) } offset += perLimit @@ -63,7 +63,7 @@ func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryC // ChunkFind 分片查询并直接返回所有结果 // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { +func ChunkFind[T Model](ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) { r, err = chunkFind[T](globalBb, ctx, perLimit, q) return } @@ -71,7 +71,7 @@ func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r // ChunkFindFromDB 同 ChunkFind // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { +func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) { r, err = chunkFind[T](db, ctx, perLimit, q) return } @@ -79,7 +79,7 @@ func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q * // Chunk 分片查询并函数过虑返回新类型的切片 // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { +func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) { r, err = chunk(globalBb, ctx, perLimit, fn, q) return } @@ -87,12 +87,12 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R // ChunkFromDB 同 Chunk // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { +func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) { r, err = chunk(db, ctx, perLimit, fn, q) return } -func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { +func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) { i := 1 var rr []T var count int @@ -100,10 +100,10 @@ func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn fun var offset int for { if 1 == i { - rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) + rr, total, err = pagination[T](db, ctx, q) } else { - q.offset = offset - q.limit = perLimit + q.Offset = offset + q.Limit = perLimit rr, err = finds[T](db, ctx, q) } offset += perLimit @@ -125,28 +125,28 @@ func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn fun return } -// Pagination 同 SimplePagination +// Pagination 同 // -// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数 -func Pagination[T Model](ctx context.Context, q *QueryCondition) ([]T, int, error) { - return SimplePagination[T](ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) +// Condition 中可使用 Where Fields From Group Having Join Order Page Limit In 函数 +func Pagination[T Model](ctx context.Context, q QueryCondition) ([]T, int, error) { + return pagination[T](globalBb, ctx, q) } // PaginationFromDB 同 Pagination 方便多个db使用 // // Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数 -func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) ([]T, int, error) { - return pagination[T](db, ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) +func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) ([]T, int, error) { + return pagination[T](db, ctx, q) } -func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q *QueryCondition) ([]T, error) { +func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q QueryCondition) ([]T, error) { return column[V, T](globalBb, ctx, fn, q) } -func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { +func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q QueryCondition) (r []T, err error) { return column[V, T](db, ctx, fn, q) } -func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { +func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q QueryCondition) (r []T, err error) { res, err := finds[V](db, ctx, q) if err != nil { return nil, err @@ -155,13 +155,13 @@ func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool return } -func GetField[T Model](ctx context.Context, field string, q *QueryCondition) (r string, err error) { +func GetField[T Model](ctx context.Context, field string, q QueryCondition) (r string, err error) { r, err = getField[T](globalBb, ctx, field, q) return } -func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { - if q.fields == "" || q.fields == "*" { - q.fields = field +func getField[T Model](db dbQuery, ctx context.Context, field string, q QueryCondition) (r string, err error) { + if q.Fields == "" || q.Fields == "*" { + q.Fields = field } res, err := getToStringMap[T](db, ctx, q) if err != nil { @@ -174,11 +174,11 @@ func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCo } return } -func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { +func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q QueryCondition) (r string, err error) { return getField[T](db, ctx, field, q) } -func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { +func getToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) { rawSql, in, err := BuildQuerySql[T](q) if err != nil { return nil, err @@ -187,12 +187,12 @@ func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) err = db.Get(ctx, &r, rawSql, in...) return } -func GetToStringMap[T Model](ctx context.Context, q *QueryCondition) (r map[string]string, err error) { +func GetToStringMap[T Model](ctx context.Context, q QueryCondition) (r map[string]string, err error) { r, err = getToStringMap[T](globalBb, ctx, q) return } -func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { +func findToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) { rawSql, in, err := BuildQuerySql[T](q) if err != nil { return nil, err @@ -202,33 +202,33 @@ func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition return } -func FindToStringMap[T Model](ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { +func FindToStringMap[T Model](ctx context.Context, q QueryCondition) (r []map[string]string, err error) { r, err = findToStringMap[T](globalBb, ctx, q) return } -func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { +func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) { r, err = findToStringMap[T](db, ctx, q) return } -func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { +func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) { r, err = getToStringMap[T](db, ctx, q) return } -func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error) { +func BuildQuerySql[T Model](q QueryCondition) (r string, args []any, err error) { var rr T w := "" - if q.where != nil { - w, args, err = q.where.ParseWhere(&q.in) + if q.Where != nil { + w, args, err = q.Where.ParseWhere(&q.In) if err != nil { return } } h := "" - if q.having != nil { - hh, arg, er := q.having.ParseWhere(&q.in) + if q.Having != nil { + hh, arg, er := q.Having.ParseWhere(&q.In) if er != nil { err = er return @@ -236,32 +236,37 @@ func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error) args = append(args, arg...) h = strings.Replace(hh, " where", " having", 1) } + if len(args) == 0 && len(q.In) > 0 { + for _, antes := range q.In { + args = append(args, antes...) + } + } - j := q.join.parseJoin() + j := q.Join.parseJoin() groupBy := "" - if q.group != "" { + if q.Group != "" { g := strings.Builder{} g.WriteString(" group by ") - g.WriteString(q.group) + g.WriteString(q.Group) groupBy = g.String() } tp := "select %s from %s %s %s %s %s %s %s" l := "" table := rr.Table() - if q.from != "" { - table = q.from + if q.From != "" { + table = q.From } - if q.limit > 0 { - l = fmt.Sprintf(" limit %d", q.limit) + if q.Limit > 0 { + l = fmt.Sprintf(" limit %d", q.Limit) } - if q.offset > 0 { - l = fmt.Sprintf(" %s offset %d", l, q.offset) + if q.Offset > 0 { + l = fmt.Sprintf(" %s offset %d", l, q.Offset) } - r = fmt.Sprintf(tp, q.fields, table, j, w, groupBy, h, q.order.parseOrderBy(), l) + r = fmt.Sprintf(tp, q.Fields, table, j, w, groupBy, h, q.Order.parseOrderBy(), l) return } -func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) (err error) { +func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) (err error) { s, args, err := BuildQuerySql[T](q) if err != nil { return @@ -275,22 +280,22 @@ func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryC return } -func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) error { +func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) error { return findScanner[T](db, ctx, fn, q) } -func FindScanner[T Model](ctx context.Context, fn func(T), q *QueryCondition) error { +func FindScanner[T Model](ctx context.Context, fn func(T), q QueryCondition) error { return findScanner[T](globalBb, ctx, fn, q) } -func Gets[T Model](ctx context.Context, q *QueryCondition) (T, error) { +func Gets[T Model](ctx context.Context, q QueryCondition) (T, error) { return gets[T](globalBb, ctx, q) } -func GetsFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (T, error) { +func GetsFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (T, error) { return gets[T](db, ctx, q) } -func gets[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r T, err error) { +func gets[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r T, err error) { s, args, err := BuildQuerySql[T](q) if err != nil { return diff --git a/model/querycondition_test.go b/model/querycondition_test.go index 4bcc7f0..636489b 100644 --- a/model/querycondition_test.go +++ b/model/querycondition_test.go @@ -15,7 +15,7 @@ import ( func TestFinds(t *testing.T) { type args struct { ctx context.Context - q *QueryCondition + q QueryCondition } type testCase[T Model] struct { name string @@ -66,7 +66,7 @@ func TestChunkFind(t *testing.T) { type args struct { ctx context.Context perLimit int - q *QueryCondition + q QueryCondition } type testCase[T Model] struct { name string @@ -118,7 +118,7 @@ func TestChunk(t *testing.T) { ctx context.Context perLimit int fn func(rows T) (R, bool) - q *QueryCondition + q QueryCondition } type testCase[T Model, R any] struct { name string @@ -179,7 +179,7 @@ func TestChunk(t *testing.T) { func TestPagination(t *testing.T) { type args struct { ctx context.Context - q *QueryCondition + q QueryCondition } type testCase[T Model] struct { name string @@ -236,7 +236,7 @@ func TestColumn(t *testing.T) { type args[V Model, T any] struct { ctx context.Context fn func(V) (T, bool) - q *QueryCondition + q QueryCondition } type testCase[V Model, T any] struct { name string @@ -333,7 +333,7 @@ func Test_getField(t *testing.T) { db := glob field := "count(*)" q := Conditions() - wantR := "386" + wantR := "387" wantErr := false t.Run(name, func(t *testing.T) { gotR, err := getField[options](db, ctx, field, q) @@ -352,7 +352,7 @@ func Test_getToStringMap(t *testing.T) { type args struct { db dbQuery ctx context.Context - q *QueryCondition + q QueryCondition } tests := []struct { name string @@ -407,7 +407,7 @@ func Test_findToStringMap(t *testing.T) { type args struct { db dbQuery ctx context.Context - q *QueryCondition + q QueryCondition } tests := []struct { name string @@ -482,7 +482,7 @@ func Test_findScanner(t *testing.T) { db dbQuery ctx context.Context fn func(T) - q *QueryCondition + q QueryCondition } type testCase[T Model] struct { name string @@ -545,7 +545,7 @@ func Test_gets(t *testing.T) { type args struct { db dbQuery ctx context.Context - q *QueryCondition + q QueryCondition } type testCase[T Model] struct { name string @@ -583,7 +583,7 @@ func Test_finds(t *testing.T) { type args struct { db dbQuery ctx context.Context - q *QueryCondition + q QueryCondition } type testCase[T Model] struct { name string