From 1ecfa19fd44845aefd3014679d9da01974d9c07a Mon Sep 17 00:00:00 2001 From: xing Date: Wed, 22 Feb 2023 16:53:53 +0800 Subject: [PATCH] =?UTF-8?q?dbquery=E6=B3=9B=E5=9E=8B=E5=8C=96=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2=EF=BC=8C=E4=BC=BC=E4=B9=8E=E6=80=A7=E8=83=BD=E5=92=8C?= =?UTF-8?q?=E4=B9=8B=E5=89=8D=E4=B9=9F=E5=B7=AE=E4=B8=8D=E5=A4=9A=E7=9A=84?= =?UTF-8?q?=E6=A0=B7=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/condition.go | 7 + model/model.go | 13 +- model/query.go | 188 +----------- model/query_test.go | 551 ++++++--------------------------- model/querycondition.go | 254 +--------------- model/querycondition_test.go | 570 ----------------------------------- model/sqxquery.go | 34 ++- model/universal.go | 18 +- 8 files changed, 143 insertions(+), 1492 deletions(-) diff --git a/model/condition.go b/model/condition.go index da4ef85..7da10d8 100644 --- a/model/condition.go +++ b/model/condition.go @@ -2,6 +2,7 @@ package model type QueryCondition struct { where ParseWhere + from string fields string group string order SqlBuilder @@ -37,6 +38,12 @@ func Fields(fields string) Condition { } } +func From(from string) Condition { + return func(c *QueryCondition) { + c.from = from + } +} + func Group(group string) Condition { return func(c *QueryCondition) { c.group = group diff --git a/model/model.go b/model/model.go index c2d83a8..f611489 100644 --- a/model/model.go +++ b/model/model.go @@ -5,13 +5,14 @@ import ( ) var _ ParseWhere = SqlBuilder{} -var globalBb dbQuery +var globalBb dbQuery[Model] -func InitDB(db dbQuery) { +func InitDB(db dbQuery[Model]) { globalBb = db } -type QueryFn func(context.Context, any, string, ...any) error +type QuerySelect[T any] func(context.Context, string, ...any) ([]T, error) +type QueryGet[T any] func(context.Context, string, ...any) (T, error) type Model interface { PrimaryKey() string @@ -22,9 +23,9 @@ type ParseWhere interface { ParseWhere(*[][]any) (string, []any, error) } -type dbQuery interface { - Select(context.Context, any, string, ...any) error - Get(context.Context, any, string, ...any) error +type dbQuery[T any] interface { + Select(context.Context, string, ...any) ([]T, error) + Get(context.Context, string, ...any) (T, error) } type SqlBuilder [][]string diff --git a/model/query.go b/model/query.go index 5a45a36..dd51ef6 100644 --- a/model/query.go +++ b/model/query.go @@ -1,185 +1,25 @@ package model -import ( - "context" - "fmt" - "golang.org/x/exp/constraints" - "math/rand" - "strings" -) - -func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { - var rr T - var w string - var args []any - if where != nil { - w, args, err = where.ParseWhere(&in) - if err != nil { - return r, total, err - } - } - - h := "" - if having != nil { - hh, arg, err := having.ParseWhere(&in) - if err != nil { - return r, total, err - } - args = append(args, arg...) - h = strings.Replace(hh, " where", " having", 1) - } - - n := struct { - N int `db:"n" json:"n"` - }{} - groupBy := "" - if group != "" { - g := strings.Builder{} - g.WriteString(" group by ") - g.WriteString(group) - groupBy = g.String() - } - if having != nil { - tm := map[string]struct{}{} - for _, s := range strings.Split(group, ",") { - tm[s] = struct{}{} - } - for _, ss := range having { - if _, ok := tm[ss[0]]; !ok { - group = fmt.Sprintf("%s,%s", group, ss[0]) - } - } - group = strings.Trim(group, ",") - } - j := join.parseJoin() - if group == "" { - tpx := "select count(*) n from %s %s %s limit 1" - sq := fmt.Sprintf(tpx, rr.Table(), j, w) - err = db.Get(ctx, &n, sq, args...) - } else { - tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s" - sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int())) - err = db.Get(ctx, &n, sq, args...) - } +import "context" +func finds[T Model](db dbQuery[T], ctx context.Context, q *QueryCondition) ([]T, error) { + s, args, err := BuildQuerySql[T](q) if err != nil { - return + return nil, err } - if n.N == 0 { - return - } - total = n.N - offset := 0 - if page > 1 { - offset = (page - 1) * pageSize - } - if offset >= total { - return - } - tp := "select %s from %s %s %s %s %s %s limit %d,%d" - sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) - err = db.Select(ctx, &r, sq, args...) - if err != nil { - return - } - return + return db.Select(ctx, s, args...) } -func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { - r, total, err = pagination[T](globalBb, ctx, where, fields, group, page, pageSize, order, join, having, in...) - return -} - -func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) { - var r T - sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) - err := globalBb.Get(ctx, &r, sq, id) +func scanners[T Model](db dbQuery[T], ctx context.Context, q *QueryCondition) ([]T, error) { + s, args, err := BuildQuerySql[T](q) if err != nil { - return r, err + return nil, err } - return r, nil -} - -func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (r T, err error) { - s, args, err := BuildQuerySql[T](&QueryCondition{ - where: where, - fields: fields, - order: order, - in: in, - limit: 1, - }) - if err != nil { - return - } - err = globalBb.Get(ctx, &r, s, args...) - return -} - -func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (T, error) { - var r T - var w string - var args []any - var err error - if where != nil { - w, args, err = where.ParseWhere(&in) - if err != nil { - return r, err - } - } - tp := "select %s from %s %s order by %s desc limit 1" - sq := fmt.Sprintf(tp, fields, r.Table(), w, r.PrimaryKey()) - err = globalBb.Get(ctx, &r, sq, args...) - if err != nil { - return r, err - } - return r, nil -} - -func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) { - s, args, err := BuildQuerySql[T](&QueryCondition{ - where: where, - fields: fields, - in: in, - }) - if err != nil { - return - } - err = globalBb.Select(ctx, &r, s, args...) - return r, nil -} - -func Select[T Model](ctx context.Context, sql string, params ...any) ([]T, error) { + ctx = context.WithValue(ctx, "handle=>", "scanner") var r []T - var rr T - sql = strings.Replace(sql, "{table}", rr.Table(), -1) - err := globalBb.Select(ctx, &r, sql, params...) - if err != nil { - return r, err - } - return r, nil -} - -func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) { - q := QueryCondition{ - where: where, - fields: fields, - group: group, - order: order, - join: join, - having: having, - limit: limit, - in: in, - } - s, args, err := BuildQuerySql[T](&q) - if err != nil { - return - } - err = globalBb.Select(ctx, &r, s, args...) - return -} - -func Get[T Model](ctx context.Context, sql string, params ...any) (r T, err error) { - sql = strings.Replace(sql, "{table}", r.Table(), -1) - err = globalBb.Get(ctx, &r, sql, params...) - return + ctx = context.WithValue(ctx, "fn", func(t T) { + r = append(r, t) + }) + _, err = db.Select(ctx, s, args...) + return r, err } diff --git a/model/query_test.go b/model/query_test.go index a47482a..e53fd3d 100644 --- a/model/query_test.go +++ b/model/query_test.go @@ -2,517 +2,136 @@ package model import ( "context" - "database/sql" - "github.com/fthvgb1/wp-go/helper/number" - "github.com/fthvgb1/wp-go/helper/slice" + "github.com/fthvgb1/wp-go/safety" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" "reflect" + "sync" "testing" - "time" ) -type post struct { - Id uint64 `gorm:"column:ID" db:"ID" json:"ID" form:"ID"` - PostAuthor uint64 `gorm:"column:post_author" db:"post_author" json:"post_author" form:"post_author"` - PostDate time.Time `gorm:"column:post_date" db:"post_date" json:"post_date" form:"post_date"` - PostDateGmt time.Time `gorm:"column:post_date_gmt" db:"post_date_gmt" json:"post_date_gmt" form:"post_date_gmt"` - PostContent string `gorm:"column:post_content" db:"post_content" json:"post_content" form:"post_content"` - PostTitle string `gorm:"column:post_title" db:"post_title" json:"post_title" form:"post_title"` - PostExcerpt string `gorm:"column:post_excerpt" db:"post_excerpt" json:"post_excerpt" form:"post_excerpt"` - PostStatus string `gorm:"column:post_status" db:"post_status" json:"post_status" form:"post_status"` - CommentStatus string `gorm:"column:comment_status" db:"comment_status" json:"comment_status" form:"comment_status"` - PingStatus string `gorm:"column:ping_status" db:"ping_status" json:"ping_status" form:"ping_status"` - PostPassword string `gorm:"column:post_password" db:"post_password" json:"post_password" form:"post_password"` - PostName string `gorm:"column:post_name" db:"post_name" json:"post_name" form:"post_name"` - ToPing string `gorm:"column:to_ping" db:"to_ping" json:"to_ping" form:"to_ping"` - Pinged string `gorm:"column:pinged" db:"pinged" json:"pinged" form:"pinged"` - PostModified time.Time `gorm:"column:post_modified" db:"post_modified" json:"post_modified" form:"post_modified"` - PostModifiedGmt time.Time `gorm:"column:post_modified_gmt" db:"post_modified_gmt" json:"post_modified_gmt" form:"post_modified_gmt"` - PostContentFiltered string `gorm:"column:post_content_filtered" db:"post_content_filtered" json:"post_content_filtered" form:"post_content_filtered"` - PostParent uint64 `gorm:"column:post_parent" db:"post_parent" json:"post_parent" form:"post_parent"` - Guid string `gorm:"column:guid" db:"guid" json:"guid" form:"guid"` - MenuOrder int `gorm:"column:menu_order" db:"menu_order" json:"menu_order" form:"menu_order"` - PostType string `gorm:"column:post_type" db:"post_type" json:"post_type" form:"post_type"` - PostMimeType string `gorm:"column:post_mime_type" db:"post_mime_type" json:"post_mime_type" form:"post_mime_type"` - CommentCount int64 `gorm:"column:comment_count" db:"comment_count" json:"comment_count" form:"comment_count"` -} - -type user struct { - Id uint64 `gorm:"column:ID" db:"ID" json:"ID"` - UserLogin string `gorm:"column:user_login" db:"user_login" json:"user_login"` - UserPass string `gorm:"column:user_pass" db:"user_pass" json:"user_pass"` - UserNicename string `gorm:"column:user_nicename" db:"user_nicename" json:"user_nicename"` - UserEmail string `gorm:"column:user_email" db:"user_email" json:"user_email"` - UserUrl string `gorm:"column:user_url" db:"user_url" json:"user_url"` - UserRegistered time.Time `gorm:"column:user_registered" db:"user_registered" json:"user_registered"` - UserActivationKey string `gorm:"column:user_activation_key" db:"user_activation_key" json:"user_activation_key"` - UserStatus int `gorm:"column:user_status" db:"user_status" json:"user_status"` - DisplayName string `gorm:"column:display_name" db:"display_name" json:"display_name"` -} - -type termTaxonomy struct { - TermTaxonomyId uint64 `gorm:"column:term_taxonomy_id" db:"term_taxonomy_id" json:"term_taxonomy_id" form:"term_taxonomy_id"` - TermId uint64 `gorm:"column:term_id" db:"term_id" json:"term_id" form:"term_id"` - Taxonomy string `gorm:"column:taxonomy" db:"taxonomy" json:"taxonomy" form:"taxonomy"` - Description string `gorm:"column:description" db:"description" json:"description" form:"description"` - Parent uint64 `gorm:"column:parent" db:"parent" json:"parent" form:"parent"` - Count int64 `gorm:"column:count" db:"count" json:"count" form:"count"` -} - -type terms struct { - TermId uint64 `gorm:"column:term_id" db:"term_id" json:"term_id" form:"term_id"` - Name string `gorm:"column:name" db:"name" json:"name" form:"name"` - Slug string `gorm:"column:slug" db:"slug" json:"slug" form:"slug"` - TermGroup int64 `gorm:"column:term_group" db:"term_group" json:"term_group" form:"term_group"` -} - -func (t terms) PrimaryKey() string { - return "term_id" -} -func (t terms) Table() string { - return "wp_terms" -} - -func (w termTaxonomy) PrimaryKey() string { - return "term_taxonomy_id" -} - -func (w termTaxonomy) Table() string { - return "wp_term_taxonomy" -} - -func (u user) Table() string { - return "wp_users" -} - -func (u user) PrimaryKey() string { - return "ID" -} - -func (p post) PrimaryKey() string { - return "ID" -} - -func (p post) Table() string { - return "wp_posts" -} - var ctx = context.Background() -var glob *SqlxQuery +var glob = safety.NewMap[string, dbQuery[Model]]() +var dbMap = sync.Map{} + +var sq *sqlx.DB + +func anyDb[T Model]() *SqlxQuery[T] { + var a T + db, ok := dbMap.Load(a.Table()) + if ok { + return db.(*SqlxQuery[T]) + } + dbb := NewSqlxQuery[T](sq, UniversalDb[T]{nil, nil}) + dbMap.Store(a.Table(), dbb) + return dbb +} func init() { db, err := sqlx.Open("mysql", "root:root@tcp(192.168.66.47:3306)/wordpress?charset=utf8mb4&parseTime=True&loc=Local") if err != nil { panic(err) } - glob = NewSqlxQuery(db, NewUniversalDb(nil, nil)) - InitDB(glob) -} -func TestFind(t *testing.T) { - type args struct { - where ParseWhere - fields string - group string - order SqlBuilder - join SqlBuilder - having SqlBuilder - limit int - in [][]any - } - type posts struct { - post - N int `db:"n"` - } - tests := []struct { - name string - args args - wantR []posts - wantErr bool - }{ - { - name: "in,orderBy", - args: args{ - where: SqlBuilder{{ - "post_status", "publish", - }, {"ID", "in", ""}}, - fields: "*", - group: "", - order: SqlBuilder{{"ID", "desc"}}, - join: nil, - having: nil, - limit: 0, - in: [][]any{{1, 2, 3, 4}}, - }, - wantR: func() []posts { - r, err := Select[posts](ctx, "select * from "+posts{}.Table()+" where post_status='publish' and ID in (1,2,3,4) order by ID desc") - if err != nil { - panic(err) - } - return r - }(), - wantErr: false, - }, - { - name: "or", - args: args{ - where: SqlBuilder{{ - "and", "ID", "=", "1", "int", - }, {"or", "ID", "=", "2", "int"}}, - fields: "*", - group: "", - order: nil, - join: nil, - having: nil, - limit: 0, - in: nil, - }, - wantR: func() []posts { - r, err := Select[posts](ctx, "select * from "+posts{}.Table()+" where (ID=1 or ID=2)") - if err != nil { - panic(err) - } - return r - }(), - }, - { - name: "group,having", - args: args{ - where: SqlBuilder{ - {"ID", "<", "1000", "int"}, - }, - fields: "post_status,count(*) n", - group: "post_status", - order: nil, - join: nil, - having: SqlBuilder{ - {"n", ">", "1"}, - }, - limit: 0, - in: nil, - }, - wantR: func() []posts { - r, err := Select[posts](ctx, "select post_status,count(*) n from "+post{}.Table()+" where ID<1000 group by post_status having n>1") - if err != nil { - panic(err) - } - return r - }(), - }, - { - name: "or、多个in", - args: args{ - where: SqlBuilder{ - {"and", "ID", "in", "", "", "or", "ID", "in", "", ""}, - {"or", "post_status", "=", "publish", "", "and", "post_status", "=", "closed", ""}, - }, - fields: "*", - group: "", - order: nil, - join: nil, - having: nil, - limit: 0, - in: [][]any{{1, 2, 3}, {4, 5, 6}}, - }, - wantR: func() []posts { - r, err := Select[posts](ctx, "select * from "+posts{}.Table()+" where (ID in (1,2,3) or ID in (4,5,6)) or (post_status='publish' and post_status='closed')") - if err != nil { - panic(err) - } - return r - }(), - }, - { - name: "all", - args: args{ - where: SqlBuilder{ - {"b.user_login", "in", ""}, - {"and", "a.post_type", "=", "post", "", "or", "a.post_type", "=", "page", ""}, - {"a.comment_count", ">", "0", "int"}, - {"a.post_status", "publish"}, - {"e.name", "in", ""}, - {"d.taxonomy", "category"}, - }, - fields: "post_author,count(*) n", - group: "a.post_author", - order: SqlBuilder{{"n", "desc"}}, - join: SqlBuilder{ - {"a", "left join", user{}.Table() + " b", "a.post_author=b.ID"}, - {"left join", "wp_term_relationships c", "a.Id=c.object_id"}, - {"left join", termTaxonomy{}.Table() + " d", "c.term_taxonomy_id=d.term_taxonomy_id"}, - {"left join", terms{}.Table() + " e", "d.term_id=e.term_id"}, - }, - having: SqlBuilder{{"n", ">", "0", "int"}}, - limit: 10, - in: [][]any{{"test", "test2"}, {"web", "golang", "php"}}, - }, - wantR: func() []posts { - r, err := Select[posts](ctx, "select post_author,count(*) n from wp_posts a left join wp_users b on a.post_author=b.ID left join wp_term_relationships c on a.Id=c.object_id left join wp_term_taxonomy d on c.term_taxonomy_id=d.term_taxonomy_id left join wp_terms e on d.term_id=e.term_id where b.user_login in ('test','test2') and b.user_status=0 and (a.post_type='post' or a.post_type='page') and a.comment_count>0 and a.post_status='publish' and e.name in ('web','golang','php') and d.taxonomy='category' group by post_author having n > 0 order by n desc limit 10") - if err != nil { - panic(err) - } - return r - }(), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, err := Find[posts](ctx, tt.args.where, tt.args.fields, tt.args.group, tt.args.order, tt.args.join, tt.args.having, tt.args.limit, tt.args.in...) - if (err != nil) != tt.wantErr { - t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("Find() gotR = %v, want %v", gotR, tt.wantR) - } - }) - } + sq = db + //glob = NewSqlxQuery(db, NewUniversalDb(nil, nil)) + } -func TestFindOneById(t *testing.T) { - type args struct { - id int +func Test_selects(t *testing.T) { + type args[T Model] struct { + db dbQuery[T] + ctx context.Context + q *QueryCondition } - - tests := []struct { + type testCase[T Model] struct { name string - args args - want post + args args[T] + want []T wantErr bool - }{ + } + tests := []testCase[options]{ { name: "t1", - args: args{ - 1, + args: args[options]{ + anyDb[options](), + ctx, + Conditions(Where(SqlBuilder{{"option_name", "blogname"}})), }, - want: func() post { - r, err := Get[post](ctx, "select * from "+post{}.Table()+" where ID=?", 1) - if err != nil && err != sql.ErrNoRows { - panic(err) - } else if err == sql.ErrNoRows { - err = nil - } - return r - }(), - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := FindOneById[post](ctx, tt.args.id) - if err == sql.ErrNoRows { - err = nil - } + got, err := finds[options](tt.args.db, tt.args.ctx, tt.args.q) if (err != nil) != tt.wantErr { - t.Errorf("FindOneById() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("finds() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("FindOneById() got = %v, want %v", got, tt.want) + t.Errorf("finds() got = %v, want %v", got, tt.want) } }) } } -func TestFirstOne(t *testing.T) { - type args struct { - where ParseWhere - fields string - order SqlBuilder - in [][]any +func BenchmarkSelectXX(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := finds[options](anyDb[options](), ctx, Conditions()) + if err != nil { + panic(err) + } } - tests := []struct { +} +func BenchmarkScannerXX(b *testing.B) { + for i := 0; i < b.N; i++ { + + _, err := scanners[options](anyDb[options](), ctx, Conditions()) + if err != nil { + panic(err) + } + } +} + +type options struct { + OptionId uint64 `gorm:"column:option_id" db:"option_id" json:"option_id" form:"option_id"` + OptionName string `gorm:"column:option_name" db:"option_name" json:"option_name" form:"option_name"` + OptionValue string `gorm:"column:option_value" db:"option_value" json:"option_value" form:"option_value"` + Autoload string `gorm:"column:autoload" db:"autoload" json:"autoload" form:"autoload"` +} + +func (w options) PrimaryKey() string { + return "option_id" +} + +func (w options) Table() string { + return "wp_options" +} + +func Test_scanners(t *testing.T) { + type args[T Model] struct { + db dbQuery[T] + ctx context.Context + q *QueryCondition + } + type testCase[T Model] struct { name string - args args - want post + args args[T] wantErr bool - }{ + } + tests := []testCase[options]{ { name: "t1", - args: args{ - where: SqlBuilder{{"post_status", "publish"}}, - fields: "*", - order: SqlBuilder{{"ID", "desc"}}, - in: nil, + args: args[options]{ + anyDb[options](), + ctx, + Conditions(Where(SqlBuilder{{"option_name", "blogname"}})), }, - wantErr: false, - want: func() post { - r, err := Get[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' order by ID desc limit 1") - if err != nil && err != sql.ErrNoRows { - panic(err) - } else if err == sql.ErrNoRows { - err = nil - } - return r - }(), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := FirstOne[post](ctx, tt.args.where, tt.args.fields, tt.args.order, tt.args.in...) - if (err != nil) != tt.wantErr { - t.Errorf("FirstOne() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("FirstOne() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestLastOne(t *testing.T) { - type args struct { - where ParseWhere - fields string - in [][]any - } - tests := []struct { - name string - args args - want post - wantErr bool - }{ - { - name: "t1", - args: args{ - where: SqlBuilder{{ - "post_status", "publish", - }}, - fields: "*", - in: nil, - }, - want: func() post { - r, err := Get[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' order by "+post{}.PrimaryKey()+" desc limit 1") - if err != nil { - panic(err) - } - return r - }(), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := LastOne[post](ctx, tt.args.where, tt.args.fields, tt.args.in...) - if (err != nil) != tt.wantErr { - t.Errorf("LastOne() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("LastOne() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestSimpleFind(t *testing.T) { - type args struct { - where ParseWhere - fields string - in [][]any - } - tests := []struct { - name string - args args - want []post - wantErr bool - }{ - { - name: "t1", - args: args{ - where: SqlBuilder{ - {"ID", "in", ""}, - }, - fields: "*", - in: [][]any{{1, 2}}, - }, - want: func() (r []post) { - r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?)", 1, 2) - if err != nil && err != sql.ErrNoRows { - panic(err) - } else if err == sql.ErrNoRows { - err = nil - } - return - }(), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := SimpleFind[post](ctx, tt.args.where, tt.args.fields, tt.args.in...) - if (err != nil) != tt.wantErr { - t.Errorf("SimpleFind() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("SimpleFind() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestSimplePagination(t *testing.T) { - type args struct { - where ParseWhere - fields string - group string - page int - pageSize int - order SqlBuilder - join SqlBuilder - having SqlBuilder - in [][]any - } - tests := []struct { - name string - args args - wantR []post - wantTotal int - wantErr bool - }{ - { - name: "t1", - args: args{ - where: SqlBuilder{ - {"ID", "in", ""}, - }, - fields: "*", - group: "", - page: 1, - pageSize: 5, - order: nil, - join: nil, - having: nil, - in: [][]any{slice.ToAnySlice(number.Range(431, 440, 1))}, - }, - wantR: func() (r []post) { - r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?,?,?,?)", slice.ToAnySlice(number.Range(431, 435, 1))...) - if err != nil && err != sql.ErrNoRows { - panic(err) - } else if err == sql.ErrNoRows { - err = nil - } - return - }(), - wantTotal: 10, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, gotTotal, err := SimplePagination[post](ctx, tt.args.where, tt.args.fields, tt.args.group, tt.args.page, tt.args.pageSize, tt.args.order, tt.args.join, tt.args.having, tt.args.in...) - if (err != nil) != tt.wantErr { - t.Errorf("SimplePagination() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("SimplePagination() gotR = %v, want %v", gotR, tt.wantR) - } - if gotTotal != tt.wantTotal { - t.Errorf("SimplePagination() gotTotal = %v, want %v", gotTotal, tt.wantTotal) + if _, err := scanners[options](tt.args.db, tt.args.ctx, tt.args.q); (err != nil) != tt.wantErr { + t.Errorf("scanners() error = %v, wantErr %v", err, tt.wantErr) } }) } diff --git a/model/querycondition.go b/model/querycondition.go index 31c4514..c11f5b9 100644 --- a/model/querycondition.go +++ b/model/querycondition.go @@ -1,222 +1,10 @@ package model import ( - "context" - "database/sql" - "errors" "fmt" - "github.com/fthvgb1/wp-go/helper/slice" "strings" ) -// Finds 比 Find 多一个offset -// -// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 -func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { - r, err = finds[T](globalBb, ctx, q) - return -} - -// FindFromDB 同 Finds 使用指定 db 查询 -// -// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 -func FindFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { - r, err = finds[T](db, ctx, q) - return -} - -func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { - sq, args, err := BuildQuerySql[T](q) - if err != nil { - return - } - err = db.Select(ctx, &r, sq, args...) - return -} - -func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { - i := 1 - var rr []T - var total int - var offset int - for { - if 1 == i { - rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) - } else { - q.offset = offset - q.limit = perLimit - rr, err = finds[T](db, ctx, q) - } - offset += perLimit - if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { - return - } - r = append(r, rr...) - if len(r) >= total { - break - } - i++ - } - return -} - -// ChunkFind 分片查询并直接返回所有结果 -// -// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { - r, err = chunkFind[T](globalBb, ctx, perLimit, q) - return -} - -// ChunkFindFromDB 同 ChunkFind -// -// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { - r, err = chunkFind[T](db, ctx, perLimit, q) - return -} - -// Chunk 分片查询并函数过虑返回新类型的切片 -// -// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { - r, err = chunk(globalBb, ctx, perLimit, fn, q) - return -} - -// ChunkFromDB 同 Chunk -// -// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { - r, err = chunk(db, ctx, perLimit, fn, q) - return -} - -func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { - i := 1 - var rr []T - var count int - var total int - var offset int - for { - if 1 == i { - rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) - } else { - q.offset = offset - q.limit = perLimit - rr, err = finds[T](db, ctx, q) - } - offset += perLimit - if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { - return - } - for _, t := range rr { - v, ok := fn(t) - if ok { - r = append(r, v) - } - } - count += len(rr) - if count >= total { - break - } - i++ - } - return -} - -// Pagination 同 SimplePagination -// -// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数 -func Pagination[T Model](ctx context.Context, q *QueryCondition) ([]T, int, error) { - return SimplePagination[T](ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) -} - -// PaginationFromDB 同 Pagination 方便多个db使用 -// -// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数 -func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) ([]T, int, error) { - return pagination[T](db, ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) -} - -func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q *QueryCondition) ([]T, error) { - return column[V, T](globalBb, ctx, fn, q) -} -func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { - return column[V, T](db, ctx, fn, q) -} - -func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { - res, err := finds[V](db, ctx, q) - if err != nil { - return nil, err - } - r = slice.FilterAndMap(res, fn) - return -} - -func GetField[T Model](ctx context.Context, field string, q *QueryCondition) (r string, err error) { - r, err = getField[T](globalBb, ctx, field, q) - return -} -func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { - if q.fields == "" || q.fields == "*" { - q.fields = field - } - res, err := getToStringMap[T](db, ctx, q) - if err != nil { - return - } - f := strings.Split(field, " ") - r, ok := res[f[len(f)-1]] - if !ok { - err = errors.New("not exists") - } - return -} -func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { - return getField[T](db, ctx, field, q) -} - -func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { - rawSql, in, err := BuildQuerySql[T](q) - if err != nil { - return nil, err - } - ctx = context.WithValue(ctx, "handle=>", "string") - err = db.Get(ctx, &r, rawSql, in...) - return -} -func GetToStringMap[T Model](ctx context.Context, q *QueryCondition) (r map[string]string, err error) { - r, err = getToStringMap[T](globalBb, ctx, q) - return -} - -func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { - rawSql, in, err := BuildQuerySql[T](q) - if err != nil { - return nil, err - } - ctx = context.WithValue(ctx, "handle=>", "string") - err = db.Select(ctx, &r, rawSql, in...) - return -} - -func FindToStringMap[T Model](ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { - r, err = findToStringMap[T](globalBb, ctx, q) - return -} - -func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { - r, err = findToStringMap[T](db, ctx, q) - return -} - -func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { - r, err = getToStringMap[T](db, ctx, q) - return -} - func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error) { var rr T w := "" @@ -254,44 +42,10 @@ func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error) if q.offset > 0 { l = fmt.Sprintf(" %s offset %d", l, q.offset) } - r = fmt.Sprintf(tp, q.fields, rr.Table(), j, w, groupBy, h, q.order.parseOrderBy(), l) - return -} - -func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) (err error) { - s, args, err := BuildQuerySql[T](q) - if err != nil { - return + table := rr.Table() + if q.from != "" { + table = q.from } - ctx = context.WithValue(ctx, "handle=>", "scanner") - var v T - ctx = context.WithValue(ctx, "fn", func(v any) { - fn(*(v.(*T))) - }) - err = db.Select(ctx, &v, s, args...) - return -} - -func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) error { - return findScanner[T](db, ctx, fn, q) -} - -func FindScanner[T Model](ctx context.Context, fn func(T), q *QueryCondition) error { - return findScanner[T](globalBb, ctx, fn, q) -} - -func Gets[T Model](ctx context.Context, q *QueryCondition) (T, error) { - return gets[T](globalBb, ctx, q) -} -func GetsFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (T, error) { - return gets[T](db, ctx, q) -} - -func gets[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r T, err error) { - s, args, err := BuildQuerySql[T](q) - if err != nil { - return - } - err = db.Get(ctx, &r, s, args...) + r = fmt.Sprintf(tp, q.fields, table, j, w, groupBy, h, q.order.parseOrderBy(), l) return } diff --git a/model/querycondition_test.go b/model/querycondition_test.go index d2ccbde..8b53790 100644 --- a/model/querycondition_test.go +++ b/model/querycondition_test.go @@ -1,571 +1 @@ package model - -import ( - "context" - "database/sql" - "fmt" - "github.com/fthvgb1/wp-go/helper/number" - "github.com/fthvgb1/wp-go/helper/slice" - "reflect" - "strconv" - "strings" - "testing" -) - -func TestFinds(t *testing.T) { - type args struct { - ctx context.Context - q *QueryCondition - } - type testCase[T Model] struct { - name string - args args - wantR []T - wantErr bool - } - tests := []testCase[post]{ - { - name: "t1", - args: args{ - ctx: context.Background(), - q: Conditions( - Where(SqlBuilder{ - {"post_status", "publish"}, {"ID", "in", ""}}, - ), - Order(SqlBuilder{{"ID", "desc"}}), - Offset(10), - Limit(10), - In([][]any{slice.ToAnySlice(number.Range(1, 1000, 1))}...), - ), - }, - wantR: func() []post { - r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, 1000, 1), strconv.Itoa), ",")+") order by ID desc limit 10 offset 10 ") - if err != nil { - panic(err) - } - return r - }(), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, err := Finds[post](tt.args.ctx, tt.args.q) - if (err != nil) != tt.wantErr { - t.Errorf("Findx() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("Findx() gotR = %v, want %v", gotR, tt.wantR) - } - }) - } -} - -func TestChunkFind(t *testing.T) { - type args struct { - ctx context.Context - perLimit int - q *QueryCondition - } - type testCase[T Model] struct { - name string - args args - wantR []T - wantErr bool - } - n := 500 - tests := []testCase[post]{ - { - name: "in,orderBy", - args: args{ - ctx: ctx, - q: Conditions( - Where(SqlBuilder{{ - "post_status", "publish", - }, {"ID", "in", ""}}), - Order(SqlBuilder{{"ID", "desc"}}), - In([][]any{slice.ToAnySlice(number.Range(1, n, 1))}...), - ), - perLimit: 20, - }, - wantR: func() []post { - r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, n, 1), strconv.Itoa), ",")+") order by ID desc") - if err != nil { - panic(err) - } - return r - }(), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, err := ChunkFind[post](tt.args.ctx, tt.args.perLimit, tt.args.q) - if (err != nil) != tt.wantErr { - t.Errorf("ChunkFind() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("ChunkFind() gotR = %v, want %v", gotR, tt.wantR) - } - }) - } -} - -func TestChunk(t *testing.T) { - type args[T Model, R any] struct { - ctx context.Context - perLimit int - fn func(rows T) (R, bool) - q *QueryCondition - } - type testCase[T Model, R any] struct { - name string - args args[T, R] - wantR []R - wantErr bool - } - n := 500 - tests := []testCase[post, uint64]{ - { - name: "t1", - args: args[post, uint64]{ - ctx: ctx, - perLimit: 20, - fn: func(t post) (uint64, bool) { - if t.Id > 300 { - return t.Id, true - } - return 0, false - }, - q: Conditions( - Where(SqlBuilder{{ - "post_status", "publish", - }, {"ID", "in", ""}}), - Order(SqlBuilder{{"ID", "desc"}}), - In([][]any{slice.ToAnySlice(number.Range(1, n, 1))}...), - ), - }, - wantR: func() []uint64 { - r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, n, 1), strconv.Itoa), ",")+") order by ID desc") - if err != nil { - panic(err) - } - return slice.FilterAndMap(r, func(t post) (uint64, bool) { - if t.Id <= 300 { - return 0, false - } - return t.Id, true - }) - }(), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, err := Chunk[post](tt.args.ctx, tt.args.perLimit, tt.args.fn, tt.args.q) - if (err != nil) != tt.wantErr { - t.Errorf("Chunk() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("Chunk() gotR = %v, want %v", gotR, tt.wantR) - } - }) - } -} - -func TestPagination(t *testing.T) { - type args struct { - ctx context.Context - q *QueryCondition - } - type testCase[T Model] struct { - name string - args args - want []T - want1 int - wantErr bool - } - tests := []testCase[post]{ - { - name: "t1", - args: args{ - ctx: ctx, - q: Conditions( - Where(SqlBuilder{ - {"ID", "in", ""}, - }), - Page(1), - Limit(5), - In([][]any{slice.ToAnySlice(number.Range(431, 440, 1))}...), - ), - }, - want: func() (r []post) { - r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?,?,?,?)", slice.ToAnySlice(number.Range(431, 435, 1))...) - if err != nil && err != sql.ErrNoRows { - panic(err) - } else if err == sql.ErrNoRows { - err = nil - } - return - }(), - want1: 10, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, got1, err := Pagination[post](tt.args.ctx, tt.args.q) - if (err != nil) != tt.wantErr { - t.Errorf("Pagination() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Pagination() got = %v, want %v", got, tt.want) - } - if got1 != tt.want1 { - t.Errorf("Pagination() got1 = %v, want %v", got1, tt.want1) - } - }) - } -} - -func TestColumn(t *testing.T) { - type args[V Model, T any] struct { - ctx context.Context - fn func(V) (T, bool) - q *QueryCondition - } - type testCase[V Model, T any] struct { - name string - args args[V, T] - wantR []T - wantErr bool - } - tests := []testCase[post, uint64]{ - { - name: "t1", - args: args[post, uint64]{ - ctx: ctx, - fn: func(t post) (uint64, bool) { - return t.Id, true - }, - q: Conditions( - Where(SqlBuilder{ - {"ID", "<", "200", "int"}, - }), - ), - }, - wantR: []uint64{63, 64, 190, 193}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, err := Column[post](tt.args.ctx, tt.args.fn, tt.args.q) - if (err != nil) != tt.wantErr { - t.Errorf("Column() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("Column() gotR = %v, want %v", gotR, tt.wantR) - } - }) - } -} - -type options struct { - OptionId uint64 `gorm:"column:option_id" db:"option_id" json:"option_id" form:"option_id"` - OptionName string `gorm:"column:option_name" db:"option_name" json:"option_name" form:"option_name"` - OptionValue string `gorm:"column:option_value" db:"option_value" json:"option_value" form:"option_value"` - Autoload string `gorm:"column:autoload" db:"autoload" json:"autoload" form:"autoload"` -} - -func (w options) PrimaryKey() string { - return "option_id" -} - -func (w options) Table() string { - return "wp_options" -} - -func Test_getField(t *testing.T) { - { - name := "string" - db := glob - field := "option_value" - q := Conditions(Where(SqlBuilder{{"option_name", "blogname"}})) - wantR := "记录并见证自己的成长" - wantErr := false - t.Run(name, func(t *testing.T) { - gotR, err := getField[options](db, ctx, field, q) - if (err != nil) != wantErr { - t.Errorf("getField() error = %v, wantErr %v", err, wantErr) - return - } - if !reflect.DeepEqual(gotR, wantR) { - t.Errorf("getField() gotR = %v, want %v", gotR, wantR) - } - }) - } - - { - name := "t2" - db := glob - field := "option_id" - q := Conditions(Where(SqlBuilder{{"option_name", "blogname"}})) - wantR := "3" - wantErr := false - t.Run(name, func(t *testing.T) { - gotR, err := getField[options](db, ctx, field, q) - if (err != nil) != wantErr { - t.Errorf("getField() error = %v, wantErr %v", err, wantErr) - return - } - if !reflect.DeepEqual(gotR, wantR) { - t.Errorf("getField() gotR = %v, want %v", gotR, wantR) - } - }) - } - { - name := "count(*)" - db := glob - field := "count(*)" - q := Conditions() - wantR := "386" - wantErr := false - t.Run(name, func(t *testing.T) { - gotR, err := getField[options](db, ctx, field, q) - if (err != nil) != wantErr { - t.Errorf("getField() error = %v, wantErr %v", err, wantErr) - return - } - if !reflect.DeepEqual(gotR, wantR) { - t.Errorf("getField() gotR = %v, want %v", gotR, wantR) - } - }) - } -} - -func Test_getToStringMap(t *testing.T) { - type args struct { - db dbQuery - ctx context.Context - q *QueryCondition - } - tests := []struct { - name string - args args - wantR map[string]string - wantErr bool - }{ - { - name: "t1", - args: args{ - db: glob, - ctx: ctx, - q: Conditions(Where(SqlBuilder{{"option_name", "users_can_register"}})), - }, - wantR: map[string]string{ - "option_id": "5", - "option_value": "0", - "option_name": "users_can_register", - "autoload": "yes", - }, - }, - { - name: "t2", - args: args{ - db: glob, - ctx: ctx, - q: Conditions( - Where(SqlBuilder{{"option_name", "users_can_register"}}), - Fields("option_id id"), - ), - }, - wantR: map[string]string{ - "id": "5", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, err := getToStringMap[options](tt.args.db, tt.args.ctx, tt.args.q) - if (err != nil) != tt.wantErr { - t.Errorf("getToStringMap() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("getToStringMap() gotR = %v, want %v", gotR, tt.wantR) - } - }) - } -} - -func Test_findToStringMap(t *testing.T) { - type args struct { - db dbQuery - ctx context.Context - q *QueryCondition - } - tests := []struct { - name string - args args - wantR []map[string]string - wantErr bool - }{ - { - name: "t1", - args: args{ - db: glob, - ctx: ctx, - q: Conditions(Where(SqlBuilder{{"option_id", "5"}})), - }, - wantR: []map[string]string{{ - "option_id": "5", - "option_value": "0", - "option_name": "users_can_register", - "autoload": "yes", - }}, - wantErr: false, - }, - { - name: "t2", - args: args{ - db: glob, - ctx: ctx, - q: Conditions( - Where(SqlBuilder{{"option_id", "5"}}), - Fields("option_value,option_name"), - ), - }, - wantR: []map[string]string{{ - "option_value": "0", - "option_name": "users_can_register", - }}, - wantErr: false, - }, - { - name: "t3", - args: args{ - db: glob, - ctx: ctx, - q: Conditions( - Where(SqlBuilder{{"option_id", "5"}}), - Fields("option_value v,option_name k"), - ), - }, - wantR: []map[string]string{{ - "v": "0", - "k": "users_can_register", - }}, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, err := findToStringMap[options](tt.args.db, tt.args.ctx, tt.args.q) - if (err != nil) != tt.wantErr { - t.Errorf("findToStringMap() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("findToStringMap() gotR = %v, want %v", gotR, tt.wantR) - } - }) - } -} - -func Test_findScanner(t *testing.T) { - type args[T Model] struct { - db dbQuery - ctx context.Context - fn func(T) - q *QueryCondition - } - type testCase[T Model] struct { - name string - args args[T] - wantErr bool - } - tests := []testCase[options]{ - { - name: "t1", - args: args[options]{glob, ctx, func(t options) { - fmt.Println(t) - }, Conditions(Where(SqlBuilder{{"option_id", "<", "10", "int"}}))}, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := findScanner[options](tt.args.db, tt.args.ctx, tt.args.fn, tt.args.q); (err != nil) != tt.wantErr { - t.Errorf("findScanner() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func BenchmarkScannerXX(b *testing.B) { - for i := 0; i < b.N; i++ { - err := findScanner[options](glob, ctx, func(t options) { - _ = t - //fmt.Println(t) - }, Conditions(Where(SqlBuilder{{"option_id", "<", "100", "int"}}))) - if err != nil { - panic(err) - } - } -} - -func BenchmarkFindsXX(b *testing.B) { - for i := 0; i < b.N; i++ { - r, err := finds[options](glob, ctx, Conditions(Where(SqlBuilder{{"option_id", "<", "100", "int"}}))) - if err != nil { - panic(err) - } - for _, o := range r { - _ = o - //fmt.Println(o) - } - } -} - -func Test_gets(t *testing.T) { - type args struct { - db dbQuery - ctx context.Context - q *QueryCondition - } - type testCase[T Model] struct { - name string - args args - wantR T - wantErr bool - } - tests := []testCase[options]{ - { - name: "t1", - args: args{ - db: glob, - ctx: ctx, - q: Conditions(Where(SqlBuilder{{"option_name", "blogname"}})), - }, - wantR: options{3, "blogname", "记录并见证自己的成长", "yes"}, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, err := gets[options](tt.args.db, tt.args.ctx, tt.args.q) - if (err != nil) != tt.wantErr { - t.Errorf("gets() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("gets() gotR = %v, want %v", gotR, tt.wantR) - } - }) - } -} diff --git a/model/sqxquery.go b/model/sqxquery.go index 22c6784..4a7a1fa 100644 --- a/model/sqxquery.go +++ b/model/sqxquery.go @@ -9,14 +9,14 @@ import ( "strings" ) -type SqlxQuery struct { +type SqlxQuery[T any] struct { sqlx *sqlx.DB - UniversalDb + UniversalDb[T] } -func NewSqlxQuery(sqlx *sqlx.DB, u UniversalDb) *SqlxQuery { +func NewSqlxQuery[T any](sqlx *sqlx.DB, u UniversalDb[T]) *SqlxQuery[T] { - s := &SqlxQuery{sqlx: sqlx, UniversalDb: u} + s := &SqlxQuery[T]{sqlx: sqlx, UniversalDb: u} if u.selects == nil { s.UniversalDb.selects = s.Selects } @@ -26,52 +26,56 @@ func NewSqlxQuery(sqlx *sqlx.DB, u UniversalDb) *SqlxQuery { return s } -func SetSelect(db *SqlxQuery, fn func(context.Context, any, string, ...any) error) { +func SetSelect[T any](db *SqlxQuery[T], fn QuerySelect[T]) { db.selects = fn } -func SetGet(db *SqlxQuery, fn func(context.Context, any, string, ...any) error) { +func SetGet[T any](db *SqlxQuery[T], fn QueryGet[T]) { db.gets = fn } -func (r *SqlxQuery) Selects(ctx context.Context, dest any, sql string, params ...any) error { +func (s *SqlxQuery[T]) Selects(ctx context.Context, sql string, params ...any) (r []T, err error) { v := ctx.Value("handle=>") if v != nil { vv, ok := v.(string) if ok && vv != "" { switch vv { case "string": - return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...) + //return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...) case "scanner": fn := ctx.Value("fn") - return Scanner[any](r.sqlx, dest, sql, params...)(fn.(func(any))) + return nil, Scanner[T](s.sqlx, sql, params...)(fn.(func(T))) } } } - return r.sqlx.Select(dest, sql, params...) + //var a T + err = s.sqlx.Select(&r, sql, params...) + return } -func (r *SqlxQuery) Gets(ctx context.Context, dest any, sql string, params ...any) error { +func (s *SqlxQuery[T]) Gets(ctx context.Context, sql string, params ...any) (r T, err error) { v := ctx.Value("handle=>") if v != nil { vv, ok := v.(string) if ok && vv != "" { switch vv { case "string": - return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...) + //return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...) } } } - return r.sqlx.Get(dest, sql, params...) + err = s.sqlx.Get(&r, sql, params...) + return } -func Scanner[T any](db *sqlx.DB, v T, s string, params ...any) func(func(T)) error { +func Scanner[T any](db *sqlx.DB, s string, params ...any) func(func(T)) error { + var v T return func(fn func(T)) error { rows, err := db.Queryx(s, params...) if err != nil { return err } for rows.Next() { - err = rows.StructScan(v) + err = rows.StructScan(&v) if err != nil { return err } diff --git a/model/universal.go b/model/universal.go index 866c50a..e838b30 100644 --- a/model/universal.go +++ b/model/universal.go @@ -2,19 +2,15 @@ package model import "context" -type UniversalDb struct { - selects QueryFn - gets QueryFn +type UniversalDb[T any] struct { + selects QuerySelect[T] + gets QueryGet[T] } -func NewUniversalDb(selects QueryFn, gets QueryFn) UniversalDb { - return UniversalDb{selects: selects, gets: gets} +func (u *UniversalDb[T]) Select(ctx context.Context, s string, a ...any) ([]T, error) { + return u.selects(ctx, s, a...) } -func (u UniversalDb) Select(ctx context.Context, a any, s string, args ...any) error { - return u.selects(ctx, a, s, args...) -} - -func (u UniversalDb) Get(ctx context.Context, a any, s string, args ...any) error { - return u.gets(ctx, a, s, args...) +func (u *UniversalDb[T]) Get(ctx context.Context, s string, a ...any) (T, error) { + return u.gets(ctx, s, a...) }