Compare commits

...

3 Commits

Author SHA1 Message Date
xing
9ee402af4b 优化 where解析 2023-02-22 20:07:57 +08:00
xing
00f788fad5 scanner的性能比sqlx的select好些的样子 2023-02-22 19:12:26 +08:00
xing
1ecfa19fd4 dbquery泛型化查询,似乎性能和之前也差不多的样子 2023-02-22 16:53:53 +08:00
9 changed files with 183 additions and 1524 deletions

View File

@ -2,6 +2,7 @@ package model
type QueryCondition struct { type QueryCondition struct {
where ParseWhere where ParseWhere
from string
fields string fields string
group string group string
order SqlBuilder order SqlBuilder
@ -37,6 +38,12 @@ func Fields(fields string) Condition {
} }
} }
func From(from string) Condition {
return func(c *QueryCondition) {
c.from = from
}
}
func Group(group string) Condition { func Group(group string) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.group = group c.group = group

View File

@ -5,13 +5,14 @@ import (
) )
var _ ParseWhere = SqlBuilder{} var _ ParseWhere = SqlBuilder{}
var globalBb dbQuery var globalBb dbQuery[Model]
func InitDB(db dbQuery) { func InitDB(db dbQuery[Model]) {
globalBb = db globalBb = db
} }
type QueryFn func(context.Context, any, string, ...any) error type QuerySelect[T any] func(context.Context, string, ...any) ([]T, error)
type QueryGet[T any] func(context.Context, string, ...any) (T, error)
type Model interface { type Model interface {
PrimaryKey() string PrimaryKey() string
@ -22,9 +23,9 @@ type ParseWhere interface {
ParseWhere(*[][]any) (string, []any, error) ParseWhere(*[][]any) (string, []any, error)
} }
type dbQuery interface { type dbQuery[T any] interface {
Select(context.Context, any, string, ...any) error Select(context.Context, string, ...any) ([]T, error)
Get(context.Context, any, string, ...any) error Get(context.Context, string, ...any) (T, error)
} }
type SqlBuilder [][]string type SqlBuilder [][]string

View File

@ -1,40 +1,28 @@
package model package model
import ( import (
"fmt"
"github.com/fthvgb1/wp-go/helper/slice" "github.com/fthvgb1/wp-go/helper/slice"
str "github.com/fthvgb1/wp-go/helper/strings"
"strconv" "strconv"
"strings" "strings"
) )
func (w SqlBuilder) parseField(ss []string, s *strings.Builder) { func (w SqlBuilder) parseField(ss []string, s *str.Builder) {
if strings.Contains(ss[0], ".") && !strings.Contains(ss[0], "(") { if strings.Contains(ss[0], ".") && !strings.Contains(ss[0], "(") {
s.WriteString("`")
sx := strings.Split(ss[0], ".") sx := strings.Split(ss[0], ".")
s.WriteString(sx[0]) s.Sprintf("`%s`.`%s`", sx[0], sx[1])
s.WriteString("`.`")
s.WriteString(sx[1])
s.WriteString("`")
} else if !strings.Contains(ss[0], ".") && !strings.Contains(ss[0], "(") { } else if !strings.Contains(ss[0], ".") && !strings.Contains(ss[0], "(") {
s.WriteString("`") s.Sprintf("`%s`", ss[0])
s.WriteString(ss[0])
s.WriteString("`")
} else { } else {
s.WriteString(ss[0]) s.WriteString(ss[0])
} }
} }
func (w SqlBuilder) parseIn(ss []string, s *strings.Builder, c *int, args *[]any, in *[][]any) (t bool) { func (w SqlBuilder) parseIn(ss []string, s *str.Builder, c *int, args *[]any, in *[][]any) (t bool) {
if slice.IsContained(ss[1], []string{"in", "not in"}) && len(*in) > 0 { if slice.IsContained(ss[1], []string{"in", "not in"}) && len(*in) > 0 {
s.WriteString(" (") sss := strings.Repeat("?,", len((*in)[*c]))
for _, p := range (*in)[*c] { s.Sprintf("(%s)", strings.TrimRight(sss, ","))
s.WriteString("?,") *args = append(*args, (*in)[*c]...)
*args = append(*args, p)
}
sx := s.String()
s.Reset()
s.WriteString(strings.TrimRight(sx, ","))
s.WriteString(")")
*c++ *c++
t = true t = true
} }
@ -78,18 +66,18 @@ func (w SqlBuilder) parseType(ss []string, args *[]any) error {
// //
// {{"and","field","=","num1","int","or","field","=","num2","int"}} => where (field = num1 or field = num2') // {{"and","field","=","num1","int","or","field","=","num2","int"}} => where (field = num1 or field = num2')
func (w SqlBuilder) ParseWhere(in *[][]any) (string, []any, error) { func (w SqlBuilder) ParseWhere(in *[][]any) (string, []any, error) {
var s strings.Builder var s = str.NewBuilder()
args := make([]any, 0, len(w)) args := make([]any, 0, len(w))
c := 0 c := 0
for _, ss := range w { for _, ss := range w {
if len(ss) == 2 { if len(ss) == 2 {
w.parseField(ss, &s) w.parseField(ss, s)
s.WriteString("=? and ") s.WriteString("=? and ")
args = append(args, ss[1]) args = append(args, ss[1])
} else if len(ss) >= 3 && len(ss) < 5 { } else if len(ss) >= 3 && len(ss) < 5 {
w.parseField(ss, &s) w.parseField(ss, s)
s.WriteString(ss[1]) s.WriteString(ss[1])
if w.parseIn(ss, &s, &c, &args, in) { if w.parseIn(ss, s, &c, &args, in) {
s.WriteString(" and ") s.WriteString(" and ")
continue continue
} }
@ -107,15 +95,14 @@ func (w SqlBuilder) ParseWhere(in *[][]any) (string, []any, error) {
if strings.Contains(st, "and ") && ss[start] == "or" { if strings.Contains(st, "and ") && ss[start] == "or" {
st = strings.TrimRight(st, "and ") st = strings.TrimRight(st, "and ")
s.Reset() s.Reset()
s.WriteString(st) s.Sprintf("%s %s ", st, ss[start])
s.WriteString(fmt.Sprintf(" %s ", ss[start]))
} }
if i == 0 { if i == 0 {
s.WriteString("( ") s.WriteString("( ")
} }
w.parseField(ss[start+1:end], &s) w.parseField(ss[start+1:end], s)
s.WriteString(ss[start+2]) s.WriteString(ss[start+2])
if w.parseIn(ss[start+1:end], &s, &c, &args, in) { if w.parseIn(ss[start+1:end], s, &c, &args, in) {
s.WriteString(" and ") s.WriteString(" and ")
continue continue
} }
@ -128,15 +115,13 @@ func (w SqlBuilder) ParseWhere(in *[][]any) (string, []any, error) {
st := s.String() st := s.String()
st = strings.TrimRight(st, "and ") st = strings.TrimRight(st, "and ")
s.Reset() s.Reset()
s.WriteString(st) s.Sprintf("%s ) and ", st)
s.WriteString(") and ")
} }
} }
ss := strings.TrimRight(s.String(), "and ") ss := strings.TrimRight(s.String(), "and ")
if ss != "" { if ss != "" {
s.Reset() s.Reset()
s.WriteString(" where ") s.Sprintf(" where %s", ss)
s.WriteString(ss)
ss = s.String() ss = s.String()
} }
if len(*in) > c { if len(*in) > c {

View File

@ -1,185 +1,25 @@
package model package model
import ( import "context"
"context"
"fmt"
"golang.org/x/exp/constraints"
"math/rand"
"strings"
)
func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { func finds[T Model](db dbQuery[T], ctx context.Context, q *QueryCondition) ([]T, error) {
var rr T s, args, err := BuildQuerySql[T](q)
var w string
var args []any
if where != nil {
w, args, err = where.ParseWhere(&in)
if err != nil { if err != nil {
return r, total, err return nil, err
} }
} return db.Select(ctx, s, args...)
h := ""
if having != nil {
hh, arg, err := having.ParseWhere(&in)
if err != nil {
return r, total, err
}
args = append(args, arg...)
h = strings.Replace(hh, " where", " having", 1)
}
n := struct {
N int `db:"n" json:"n"`
}{}
groupBy := ""
if group != "" {
g := strings.Builder{}
g.WriteString(" group by ")
g.WriteString(group)
groupBy = g.String()
}
if having != nil {
tm := map[string]struct{}{}
for _, s := range strings.Split(group, ",") {
tm[s] = struct{}{}
}
for _, ss := range having {
if _, ok := tm[ss[0]]; !ok {
group = fmt.Sprintf("%s,%s", group, ss[0])
}
}
group = strings.Trim(group, ",")
}
j := join.parseJoin()
if group == "" {
tpx := "select count(*) n from %s %s %s limit 1"
sq := fmt.Sprintf(tpx, rr.Table(), j, w)
err = db.Get(ctx, &n, sq, args...)
} else {
tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s"
sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int()))
err = db.Get(ctx, &n, sq, args...)
}
if err != nil {
return
}
if n.N == 0 {
return
}
total = n.N
offset := 0
if page > 1 {
offset = (page - 1) * pageSize
}
if offset >= total {
return
}
tp := "select %s from %s %s %s %s %s %s limit %d,%d"
sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize)
err = db.Select(ctx, &r, sq, args...)
if err != nil {
return
}
return
} }
func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { func scanners[T Model](db dbQuery[T], ctx context.Context, q *QueryCondition) ([]T, error) {
r, total, err = pagination[T](globalBb, ctx, where, fields, group, page, pageSize, order, join, having, in...) s, args, err := BuildQuerySql[T](q)
return
}
func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) {
var r T
sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey())
err := globalBb.Get(ctx, &r, sq, id)
if err != nil { if err != nil {
return r, err return nil, err
} }
return r, nil ctx = context.WithValue(ctx, "handle=>", "scanner")
}
func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (r T, err error) {
s, args, err := BuildQuerySql[T](&QueryCondition{
where: where,
fields: fields,
order: order,
in: in,
limit: 1,
})
if err != nil {
return
}
err = globalBb.Get(ctx, &r, s, args...)
return
}
func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (T, error) {
var r T
var w string
var args []any
var err error
if where != nil {
w, args, err = where.ParseWhere(&in)
if err != nil {
return r, err
}
}
tp := "select %s from %s %s order by %s desc limit 1"
sq := fmt.Sprintf(tp, fields, r.Table(), w, r.PrimaryKey())
err = globalBb.Get(ctx, &r, sq, args...)
if err != nil {
return r, err
}
return r, nil
}
func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) {
s, args, err := BuildQuerySql[T](&QueryCondition{
where: where,
fields: fields,
in: in,
})
if err != nil {
return
}
err = globalBb.Select(ctx, &r, s, args...)
return r, nil
}
func Select[T Model](ctx context.Context, sql string, params ...any) ([]T, error) {
var r []T var r []T
var rr T ctx = context.WithValue(ctx, "fn", func(t T) {
sql = strings.Replace(sql, "{table}", rr.Table(), -1) r = append(r, t)
err := globalBb.Select(ctx, &r, sql, params...) })
if err != nil { _, err = db.Select(ctx, s, args...)
return r, err return r, err
}
return r, nil
}
func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) {
q := QueryCondition{
where: where,
fields: fields,
group: group,
order: order,
join: join,
having: having,
limit: limit,
in: in,
}
s, args, err := BuildQuerySql[T](&q)
if err != nil {
return
}
err = globalBb.Select(ctx, &r, s, args...)
return
}
func Get[T Model](ctx context.Context, sql string, params ...any) (r T, err error) {
sql = strings.Replace(sql, "{table}", r.Table(), -1)
err = globalBb.Get(ctx, &r, sql, params...)
return
} }

View File

@ -2,517 +2,159 @@ package model
import ( import (
"context" "context"
"database/sql" "github.com/fthvgb1/wp-go/safety"
"github.com/fthvgb1/wp-go/helper/number"
"github.com/fthvgb1/wp-go/helper/slice"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"reflect" "reflect"
"sync"
"testing" "testing"
"time"
) )
type post struct {
Id uint64 `gorm:"column:ID" db:"ID" json:"ID" form:"ID"`
PostAuthor uint64 `gorm:"column:post_author" db:"post_author" json:"post_author" form:"post_author"`
PostDate time.Time `gorm:"column:post_date" db:"post_date" json:"post_date" form:"post_date"`
PostDateGmt time.Time `gorm:"column:post_date_gmt" db:"post_date_gmt" json:"post_date_gmt" form:"post_date_gmt"`
PostContent string `gorm:"column:post_content" db:"post_content" json:"post_content" form:"post_content"`
PostTitle string `gorm:"column:post_title" db:"post_title" json:"post_title" form:"post_title"`
PostExcerpt string `gorm:"column:post_excerpt" db:"post_excerpt" json:"post_excerpt" form:"post_excerpt"`
PostStatus string `gorm:"column:post_status" db:"post_status" json:"post_status" form:"post_status"`
CommentStatus string `gorm:"column:comment_status" db:"comment_status" json:"comment_status" form:"comment_status"`
PingStatus string `gorm:"column:ping_status" db:"ping_status" json:"ping_status" form:"ping_status"`
PostPassword string `gorm:"column:post_password" db:"post_password" json:"post_password" form:"post_password"`
PostName string `gorm:"column:post_name" db:"post_name" json:"post_name" form:"post_name"`
ToPing string `gorm:"column:to_ping" db:"to_ping" json:"to_ping" form:"to_ping"`
Pinged string `gorm:"column:pinged" db:"pinged" json:"pinged" form:"pinged"`
PostModified time.Time `gorm:"column:post_modified" db:"post_modified" json:"post_modified" form:"post_modified"`
PostModifiedGmt time.Time `gorm:"column:post_modified_gmt" db:"post_modified_gmt" json:"post_modified_gmt" form:"post_modified_gmt"`
PostContentFiltered string `gorm:"column:post_content_filtered" db:"post_content_filtered" json:"post_content_filtered" form:"post_content_filtered"`
PostParent uint64 `gorm:"column:post_parent" db:"post_parent" json:"post_parent" form:"post_parent"`
Guid string `gorm:"column:guid" db:"guid" json:"guid" form:"guid"`
MenuOrder int `gorm:"column:menu_order" db:"menu_order" json:"menu_order" form:"menu_order"`
PostType string `gorm:"column:post_type" db:"post_type" json:"post_type" form:"post_type"`
PostMimeType string `gorm:"column:post_mime_type" db:"post_mime_type" json:"post_mime_type" form:"post_mime_type"`
CommentCount int64 `gorm:"column:comment_count" db:"comment_count" json:"comment_count" form:"comment_count"`
}
type user struct {
Id uint64 `gorm:"column:ID" db:"ID" json:"ID"`
UserLogin string `gorm:"column:user_login" db:"user_login" json:"user_login"`
UserPass string `gorm:"column:user_pass" db:"user_pass" json:"user_pass"`
UserNicename string `gorm:"column:user_nicename" db:"user_nicename" json:"user_nicename"`
UserEmail string `gorm:"column:user_email" db:"user_email" json:"user_email"`
UserUrl string `gorm:"column:user_url" db:"user_url" json:"user_url"`
UserRegistered time.Time `gorm:"column:user_registered" db:"user_registered" json:"user_registered"`
UserActivationKey string `gorm:"column:user_activation_key" db:"user_activation_key" json:"user_activation_key"`
UserStatus int `gorm:"column:user_status" db:"user_status" json:"user_status"`
DisplayName string `gorm:"column:display_name" db:"display_name" json:"display_name"`
}
type termTaxonomy struct {
TermTaxonomyId uint64 `gorm:"column:term_taxonomy_id" db:"term_taxonomy_id" json:"term_taxonomy_id" form:"term_taxonomy_id"`
TermId uint64 `gorm:"column:term_id" db:"term_id" json:"term_id" form:"term_id"`
Taxonomy string `gorm:"column:taxonomy" db:"taxonomy" json:"taxonomy" form:"taxonomy"`
Description string `gorm:"column:description" db:"description" json:"description" form:"description"`
Parent uint64 `gorm:"column:parent" db:"parent" json:"parent" form:"parent"`
Count int64 `gorm:"column:count" db:"count" json:"count" form:"count"`
}
type terms struct {
TermId uint64 `gorm:"column:term_id" db:"term_id" json:"term_id" form:"term_id"`
Name string `gorm:"column:name" db:"name" json:"name" form:"name"`
Slug string `gorm:"column:slug" db:"slug" json:"slug" form:"slug"`
TermGroup int64 `gorm:"column:term_group" db:"term_group" json:"term_group" form:"term_group"`
}
func (t terms) PrimaryKey() string {
return "term_id"
}
func (t terms) Table() string {
return "wp_terms"
}
func (w termTaxonomy) PrimaryKey() string {
return "term_taxonomy_id"
}
func (w termTaxonomy) Table() string {
return "wp_term_taxonomy"
}
func (u user) Table() string {
return "wp_users"
}
func (u user) PrimaryKey() string {
return "ID"
}
func (p post) PrimaryKey() string {
return "ID"
}
func (p post) Table() string {
return "wp_posts"
}
var ctx = context.Background() var ctx = context.Background()
var glob *SqlxQuery var glob = safety.NewMap[string, dbQuery[Model]]()
var dbMap = sync.Map{}
var sq *sqlx.DB
func anyDb[T Model]() *SqlxQuery[T] {
var a T
db, ok := dbMap.Load(a.Table())
if ok {
return db.(*SqlxQuery[T])
}
dbb := NewSqlxQuery[T](sq, UniversalDb[T]{nil, nil})
dbMap.Store(a.Table(), dbb)
return dbb
}
func init() { func init() {
db, err := sqlx.Open("mysql", "root:root@tcp(192.168.66.47:3306)/wordpress?charset=utf8mb4&parseTime=True&loc=Local") db, err := sqlx.Open("mysql", "root:root@tcp(192.168.66.47:3306)/wordpress?charset=utf8mb4&parseTime=True&loc=Local")
if err != nil { if err != nil {
panic(err) panic(err)
} }
glob = NewSqlxQuery(db, NewUniversalDb(nil, nil)) sq = db
InitDB(glob) //glob = NewSqlxQuery(db, NewUniversalDb(nil, nil))
}
func TestFind(t *testing.T) {
type args struct {
where ParseWhere
fields string
group string
order SqlBuilder
join SqlBuilder
having SqlBuilder
limit int
in [][]any
}
type posts struct {
post
N int `db:"n"`
}
tests := []struct {
name string
args args
wantR []posts
wantErr bool
}{
{
name: "in,orderBy",
args: args{
where: SqlBuilder{{
"post_status", "publish",
}, {"ID", "in", ""}},
fields: "*",
group: "",
order: SqlBuilder{{"ID", "desc"}},
join: nil,
having: nil,
limit: 0,
in: [][]any{{1, 2, 3, 4}},
},
wantR: func() []posts {
r, err := Select[posts](ctx, "select * from "+posts{}.Table()+" where post_status='publish' and ID in (1,2,3,4) order by ID desc")
if err != nil {
panic(err)
}
return r
}(),
wantErr: false,
},
{
name: "or",
args: args{
where: SqlBuilder{{
"and", "ID", "=", "1", "int",
}, {"or", "ID", "=", "2", "int"}},
fields: "*",
group: "",
order: nil,
join: nil,
having: nil,
limit: 0,
in: nil,
},
wantR: func() []posts {
r, err := Select[posts](ctx, "select * from "+posts{}.Table()+" where (ID=1 or ID=2)")
if err != nil {
panic(err)
}
return r
}(),
},
{
name: "group,having",
args: args{
where: SqlBuilder{
{"ID", "<", "1000", "int"},
},
fields: "post_status,count(*) n",
group: "post_status",
order: nil,
join: nil,
having: SqlBuilder{
{"n", ">", "1"},
},
limit: 0,
in: nil,
},
wantR: func() []posts {
r, err := Select[posts](ctx, "select post_status,count(*) n from "+post{}.Table()+" where ID<1000 group by post_status having n>1")
if err != nil {
panic(err)
}
return r
}(),
},
{
name: "or、多个in",
args: args{
where: SqlBuilder{
{"and", "ID", "in", "", "", "or", "ID", "in", "", ""},
{"or", "post_status", "=", "publish", "", "and", "post_status", "=", "closed", ""},
},
fields: "*",
group: "",
order: nil,
join: nil,
having: nil,
limit: 0,
in: [][]any{{1, 2, 3}, {4, 5, 6}},
},
wantR: func() []posts {
r, err := Select[posts](ctx, "select * from "+posts{}.Table()+" where (ID in (1,2,3) or ID in (4,5,6)) or (post_status='publish' and post_status='closed')")
if err != nil {
panic(err)
}
return r
}(),
},
{
name: "all",
args: args{
where: SqlBuilder{
{"b.user_login", "in", ""},
{"and", "a.post_type", "=", "post", "", "or", "a.post_type", "=", "page", ""},
{"a.comment_count", ">", "0", "int"},
{"a.post_status", "publish"},
{"e.name", "in", ""},
{"d.taxonomy", "category"},
},
fields: "post_author,count(*) n",
group: "a.post_author",
order: SqlBuilder{{"n", "desc"}},
join: SqlBuilder{
{"a", "left join", user{}.Table() + " b", "a.post_author=b.ID"},
{"left join", "wp_term_relationships c", "a.Id=c.object_id"},
{"left join", termTaxonomy{}.Table() + " d", "c.term_taxonomy_id=d.term_taxonomy_id"},
{"left join", terms{}.Table() + " e", "d.term_id=e.term_id"},
},
having: SqlBuilder{{"n", ">", "0", "int"}},
limit: 10,
in: [][]any{{"test", "test2"}, {"web", "golang", "php"}},
},
wantR: func() []posts {
r, err := Select[posts](ctx, "select post_author,count(*) n from wp_posts a left join wp_users b on a.post_author=b.ID left join wp_term_relationships c on a.Id=c.object_id left join wp_term_taxonomy d on c.term_taxonomy_id=d.term_taxonomy_id left join wp_terms e on d.term_id=e.term_id where b.user_login in ('test','test2') and b.user_status=0 and (a.post_type='post' or a.post_type='page') and a.comment_count>0 and a.post_status='publish' and e.name in ('web','golang','php') and d.taxonomy='category' group by post_author having n > 0 order by n desc limit 10")
if err != nil {
panic(err)
}
return r
}(),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotR, err := Find[posts](ctx, tt.args.where, tt.args.fields, tt.args.group, tt.args.order, tt.args.join, tt.args.having, tt.args.limit, tt.args.in...)
if (err != nil) != tt.wantErr {
t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("Find() gotR = %v, want %v", gotR, tt.wantR)
}
})
}
} }
func TestFindOneById(t *testing.T) { func Test_selects(t *testing.T) {
type args struct { type args[T Model] struct {
id int db dbQuery[T]
ctx context.Context
q *QueryCondition
} }
type testCase[T Model] struct {
tests := []struct {
name string name string
args args args args[T]
want post want []T
wantErr bool wantErr bool
}{ }
tests := []testCase[options]{
{ {
name: "t1", name: "t1",
args: args{ args: args[options]{
1, anyDb[options](),
ctx,
Conditions(Where(SqlBuilder{{"option_name", "blogname"}})),
}, },
want: func() post {
r, err := Get[post](ctx, "select * from "+post{}.Table()+" where ID=?", 1)
if err != nil && err != sql.ErrNoRows {
panic(err)
} else if err == sql.ErrNoRows {
err = nil
}
return r
}(),
wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := FindOneById[post](ctx, tt.args.id) got, err := finds[options](tt.args.db, tt.args.ctx, tt.args.q)
if err == sql.ErrNoRows {
err = nil
}
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("FindOneById() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("finds() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("FindOneById() got = %v, want %v", got, tt.want) t.Errorf("finds() got = %v, want %v", got, tt.want)
} }
}) })
} }
} }
func TestFirstOne(t *testing.T) { func BenchmarkSelectXX(b *testing.B) {
type args struct { for i := 0; i < b.N; i++ {
where ParseWhere _, err := finds[options](anyDb[options](), ctx, Conditions(
fields string Where(SqlBuilder{{"option_id", "<", "50", "int"}}),
order SqlBuilder //In(slice.ToAnySlice(number.Range[uint64](1, 50, 1))),
in [][]any ))
}
tests := []struct {
name string
args args
want post
wantErr bool
}{
{
name: "t1",
args: args{
where: SqlBuilder{{"post_status", "publish"}},
fields: "*",
order: SqlBuilder{{"ID", "desc"}},
in: nil,
},
wantErr: false,
want: func() post {
r, err := Get[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' order by ID desc limit 1")
if err != nil && err != sql.ErrNoRows {
panic(err)
} else if err == sql.ErrNoRows {
err = nil
}
return r
}(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := FirstOne[post](ctx, tt.args.where, tt.args.fields, tt.args.order, tt.args.in...)
if (err != nil) != tt.wantErr {
t.Errorf("FirstOne() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("FirstOne() got = %v, want %v", got, tt.want)
}
})
}
}
func TestLastOne(t *testing.T) {
type args struct {
where ParseWhere
fields string
in [][]any
}
tests := []struct {
name string
args args
want post
wantErr bool
}{
{
name: "t1",
args: args{
where: SqlBuilder{{
"post_status", "publish",
}},
fields: "*",
in: nil,
},
want: func() post {
r, err := Get[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' order by "+post{}.PrimaryKey()+" desc limit 1")
if err != nil { if err != nil {
panic(err) panic(err)
} }
return r
}(),
},
} }
for _, tt := range tests { }
t.Run(tt.name, func(t *testing.T) { func BenchmarkScannerXX(b *testing.B) {
got, err := LastOne[post](ctx, tt.args.where, tt.args.fields, tt.args.in...) for i := 0; i < b.N; i++ {
if (err != nil) != tt.wantErr { _, err := scanners[options](anyDb[options](), ctx, Conditions(
t.Errorf("LastOne() error = %v, wantErr %v", err, tt.wantErr) Where(SqlBuilder{{"option_id", "<", "50", "int"}}),
return //In(slice.ToAnySlice(number.Range[uint64](1, 50, 1))),
))
if err != nil {
panic(err)
} }
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("LastOne() got = %v, want %v", got, tt.want)
}
})
} }
} }
func TestSimpleFind(t *testing.T) { func BenchmarkSqlXQueryXX(b *testing.B) {
type args struct { var r []options
where ParseWhere /*var r []options
fields string x := number.Range[uint64](1, 50, 1)
in [][]any j := strings.TrimRight(strings.Repeat("?,", 50), ",")
} s := str.NewBuilder()
tests := []struct { s.Sprintf("select * from wp_options where option_id in (%s)", j)
name string ss := s.String()
args args a := slice.ToAnySlice(x)*/
want []post ss := "select * from wp_options where option_id < ?"
wantErr bool for i := 0; i < b.N; i++ {
}{ err := sq.Select(&r, ss, 50)
{ if err != nil {
name: "t1",
args: args{
where: SqlBuilder{
{"ID", "in", ""},
},
fields: "*",
in: [][]any{{1, 2}},
},
want: func() (r []post) {
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?)", 1, 2)
if err != nil && err != sql.ErrNoRows {
panic(err) panic(err)
} else if err == sql.ErrNoRows {
err = nil
} }
return
}(),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := SimpleFind[post](ctx, tt.args.where, tt.args.fields, tt.args.in...)
if (err != nil) != tt.wantErr {
t.Errorf("SimpleFind() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SimpleFind() got = %v, want %v", got, tt.want)
}
})
} }
} }
func TestSimplePagination(t *testing.T) { type options struct {
type args struct { OptionId uint64 `gorm:"column:option_id" db:"option_id" json:"option_id" form:"option_id"`
where ParseWhere OptionName string `gorm:"column:option_name" db:"option_name" json:"option_name" form:"option_name"`
fields string OptionValue string `gorm:"column:option_value" db:"option_value" json:"option_value" form:"option_value"`
group string Autoload string `gorm:"column:autoload" db:"autoload" json:"autoload" form:"autoload"`
page int }
pageSize int
order SqlBuilder func (w options) PrimaryKey() string {
join SqlBuilder return "option_id"
having SqlBuilder }
in [][]any
func (w options) Table() string {
return "wp_options"
}
func Test_scanners(t *testing.T) {
type args[T Model] struct {
db dbQuery[T]
ctx context.Context
q *QueryCondition
} }
tests := []struct { type testCase[T Model] struct {
name string name string
args args args args[T]
wantR []post
wantTotal int
wantErr bool wantErr bool
}{ }
tests := []testCase[options]{
{ {
name: "t1", name: "t1",
args: args{ args: args[options]{
where: SqlBuilder{ anyDb[options](),
{"ID", "in", ""}, ctx,
Conditions(Where(SqlBuilder{{"option_name", "blogname"}})),
}, },
fields: "*",
group: "",
page: 1,
pageSize: 5,
order: nil,
join: nil,
having: nil,
in: [][]any{slice.ToAnySlice(number.Range(431, 440, 1))},
},
wantR: func() (r []post) {
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?,?,?,?)", slice.ToAnySlice(number.Range(431, 435, 1))...)
if err != nil && err != sql.ErrNoRows {
panic(err)
} else if err == sql.ErrNoRows {
err = nil
}
return
}(),
wantTotal: 10,
wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotR, gotTotal, err := SimplePagination[post](ctx, tt.args.where, tt.args.fields, tt.args.group, tt.args.page, tt.args.pageSize, tt.args.order, tt.args.join, tt.args.having, tt.args.in...) if _, err := scanners[options](tt.args.db, tt.args.ctx, tt.args.q); (err != nil) != tt.wantErr {
if (err != nil) != tt.wantErr { t.Errorf("scanners() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("SimplePagination() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("SimplePagination() gotR = %v, want %v", gotR, tt.wantR)
}
if gotTotal != tt.wantTotal {
t.Errorf("SimplePagination() gotTotal = %v, want %v", gotTotal, tt.wantTotal)
} }
}) })
} }

View File

@ -1,222 +1,10 @@
package model package model
import ( import (
"context"
"database/sql"
"errors"
"fmt" "fmt"
"github.com/fthvgb1/wp-go/helper/slice"
"strings" "strings"
) )
// Finds 比 Find 多一个offset
//
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) {
r, err = finds[T](globalBb, ctx, q)
return
}
// FindFromDB 同 Finds 使用指定 db 查询
//
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
func FindFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) {
r, err = finds[T](db, ctx, q)
return
}
func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) {
sq, args, err := BuildQuerySql[T](q)
if err != nil {
return
}
err = db.Select(ctx, &r, sq, args...)
return
}
func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
i := 1
var rr []T
var total int
var offset int
for {
if 1 == i {
rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...)
} else {
q.offset = offset
q.limit = perLimit
rr, err = finds[T](db, ctx, q)
}
offset += perLimit
if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 {
return
}
r = append(r, rr...)
if len(r) >= total {
break
}
i++
}
return
}
// ChunkFind 分片查询并直接返回所有结果
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
r, err = chunkFind[T](globalBb, ctx, perLimit, q)
return
}
// ChunkFindFromDB 同 ChunkFind
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
r, err = chunkFind[T](db, ctx, perLimit, q)
return
}
// Chunk 分片查询并函数过虑返回新类型的切片
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
r, err = chunk(globalBb, ctx, perLimit, fn, q)
return
}
// ChunkFromDB 同 Chunk
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
r, err = chunk(db, ctx, perLimit, fn, q)
return
}
func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
i := 1
var rr []T
var count int
var total int
var offset int
for {
if 1 == i {
rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...)
} else {
q.offset = offset
q.limit = perLimit
rr, err = finds[T](db, ctx, q)
}
offset += perLimit
if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 {
return
}
for _, t := range rr {
v, ok := fn(t)
if ok {
r = append(r, v)
}
}
count += len(rr)
if count >= total {
break
}
i++
}
return
}
// Pagination 同 SimplePagination
//
// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数
func Pagination[T Model](ctx context.Context, q *QueryCondition) ([]T, int, error) {
return SimplePagination[T](ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...)
}
// PaginationFromDB 同 Pagination 方便多个db使用
//
// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数
func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) ([]T, int, error) {
return pagination[T](db, ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...)
}
func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q *QueryCondition) ([]T, error) {
return column[V, T](globalBb, ctx, fn, q)
}
func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) {
return column[V, T](db, ctx, fn, q)
}
func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) {
res, err := finds[V](db, ctx, q)
if err != nil {
return nil, err
}
r = slice.FilterAndMap(res, fn)
return
}
func GetField[T Model](ctx context.Context, field string, q *QueryCondition) (r string, err error) {
r, err = getField[T](globalBb, ctx, field, q)
return
}
func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) {
if q.fields == "" || q.fields == "*" {
q.fields = field
}
res, err := getToStringMap[T](db, ctx, q)
if err != nil {
return
}
f := strings.Split(field, " ")
r, ok := res[f[len(f)-1]]
if !ok {
err = errors.New("not exists")
}
return
}
func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) {
return getField[T](db, ctx, field, q)
}
func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) {
rawSql, in, err := BuildQuerySql[T](q)
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, "handle=>", "string")
err = db.Get(ctx, &r, rawSql, in...)
return
}
func GetToStringMap[T Model](ctx context.Context, q *QueryCondition) (r map[string]string, err error) {
r, err = getToStringMap[T](globalBb, ctx, q)
return
}
func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) {
rawSql, in, err := BuildQuerySql[T](q)
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, "handle=>", "string")
err = db.Select(ctx, &r, rawSql, in...)
return
}
func FindToStringMap[T Model](ctx context.Context, q *QueryCondition) (r []map[string]string, err error) {
r, err = findToStringMap[T](globalBb, ctx, q)
return
}
func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) {
r, err = findToStringMap[T](db, ctx, q)
return
}
func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) {
r, err = getToStringMap[T](db, ctx, q)
return
}
func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error) { func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error) {
var rr T var rr T
w := "" w := ""
@ -254,44 +42,10 @@ func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error)
if q.offset > 0 { if q.offset > 0 {
l = fmt.Sprintf(" %s offset %d", l, q.offset) l = fmt.Sprintf(" %s offset %d", l, q.offset)
} }
r = fmt.Sprintf(tp, q.fields, rr.Table(), j, w, groupBy, h, q.order.parseOrderBy(), l) table := rr.Table()
return if q.from != "" {
} table = q.from
func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) (err error) {
s, args, err := BuildQuerySql[T](q)
if err != nil {
return
} }
ctx = context.WithValue(ctx, "handle=>", "scanner") r = fmt.Sprintf(tp, q.fields, table, j, w, groupBy, h, q.order.parseOrderBy(), l)
var v T
ctx = context.WithValue(ctx, "fn", func(v any) {
fn(*(v.(*T)))
})
err = db.Select(ctx, &v, s, args...)
return
}
func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) error {
return findScanner[T](db, ctx, fn, q)
}
func FindScanner[T Model](ctx context.Context, fn func(T), q *QueryCondition) error {
return findScanner[T](globalBb, ctx, fn, q)
}
func Gets[T Model](ctx context.Context, q *QueryCondition) (T, error) {
return gets[T](globalBb, ctx, q)
}
func GetsFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (T, error) {
return gets[T](db, ctx, q)
}
func gets[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r T, err error) {
s, args, err := BuildQuerySql[T](q)
if err != nil {
return
}
err = db.Get(ctx, &r, s, args...)
return return
} }

View File

@ -1,571 +1 @@
package model package model
import (
"context"
"database/sql"
"fmt"
"github.com/fthvgb1/wp-go/helper/number"
"github.com/fthvgb1/wp-go/helper/slice"
"reflect"
"strconv"
"strings"
"testing"
)
func TestFinds(t *testing.T) {
type args struct {
ctx context.Context
q *QueryCondition
}
type testCase[T Model] struct {
name string
args args
wantR []T
wantErr bool
}
tests := []testCase[post]{
{
name: "t1",
args: args{
ctx: context.Background(),
q: Conditions(
Where(SqlBuilder{
{"post_status", "publish"}, {"ID", "in", ""}},
),
Order(SqlBuilder{{"ID", "desc"}}),
Offset(10),
Limit(10),
In([][]any{slice.ToAnySlice(number.Range(1, 1000, 1))}...),
),
},
wantR: func() []post {
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, 1000, 1), strconv.Itoa), ",")+") order by ID desc limit 10 offset 10 ")
if err != nil {
panic(err)
}
return r
}(),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotR, err := Finds[post](tt.args.ctx, tt.args.q)
if (err != nil) != tt.wantErr {
t.Errorf("Findx() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("Findx() gotR = %v, want %v", gotR, tt.wantR)
}
})
}
}
func TestChunkFind(t *testing.T) {
type args struct {
ctx context.Context
perLimit int
q *QueryCondition
}
type testCase[T Model] struct {
name string
args args
wantR []T
wantErr bool
}
n := 500
tests := []testCase[post]{
{
name: "in,orderBy",
args: args{
ctx: ctx,
q: Conditions(
Where(SqlBuilder{{
"post_status", "publish",
}, {"ID", "in", ""}}),
Order(SqlBuilder{{"ID", "desc"}}),
In([][]any{slice.ToAnySlice(number.Range(1, n, 1))}...),
),
perLimit: 20,
},
wantR: func() []post {
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, n, 1), strconv.Itoa), ",")+") order by ID desc")
if err != nil {
panic(err)
}
return r
}(),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotR, err := ChunkFind[post](tt.args.ctx, tt.args.perLimit, tt.args.q)
if (err != nil) != tt.wantErr {
t.Errorf("ChunkFind() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("ChunkFind() gotR = %v, want %v", gotR, tt.wantR)
}
})
}
}
func TestChunk(t *testing.T) {
type args[T Model, R any] struct {
ctx context.Context
perLimit int
fn func(rows T) (R, bool)
q *QueryCondition
}
type testCase[T Model, R any] struct {
name string
args args[T, R]
wantR []R
wantErr bool
}
n := 500
tests := []testCase[post, uint64]{
{
name: "t1",
args: args[post, uint64]{
ctx: ctx,
perLimit: 20,
fn: func(t post) (uint64, bool) {
if t.Id > 300 {
return t.Id, true
}
return 0, false
},
q: Conditions(
Where(SqlBuilder{{
"post_status", "publish",
}, {"ID", "in", ""}}),
Order(SqlBuilder{{"ID", "desc"}}),
In([][]any{slice.ToAnySlice(number.Range(1, n, 1))}...),
),
},
wantR: func() []uint64 {
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, n, 1), strconv.Itoa), ",")+") order by ID desc")
if err != nil {
panic(err)
}
return slice.FilterAndMap(r, func(t post) (uint64, bool) {
if t.Id <= 300 {
return 0, false
}
return t.Id, true
})
}(),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotR, err := Chunk[post](tt.args.ctx, tt.args.perLimit, tt.args.fn, tt.args.q)
if (err != nil) != tt.wantErr {
t.Errorf("Chunk() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("Chunk() gotR = %v, want %v", gotR, tt.wantR)
}
})
}
}
func TestPagination(t *testing.T) {
type args struct {
ctx context.Context
q *QueryCondition
}
type testCase[T Model] struct {
name string
args args
want []T
want1 int
wantErr bool
}
tests := []testCase[post]{
{
name: "t1",
args: args{
ctx: ctx,
q: Conditions(
Where(SqlBuilder{
{"ID", "in", ""},
}),
Page(1),
Limit(5),
In([][]any{slice.ToAnySlice(number.Range(431, 440, 1))}...),
),
},
want: func() (r []post) {
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?,?,?,?)", slice.ToAnySlice(number.Range(431, 435, 1))...)
if err != nil && err != sql.ErrNoRows {
panic(err)
} else if err == sql.ErrNoRows {
err = nil
}
return
}(),
want1: 10,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, err := Pagination[post](tt.args.ctx, tt.args.q)
if (err != nil) != tt.wantErr {
t.Errorf("Pagination() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Pagination() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Pagination() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestColumn(t *testing.T) {
type args[V Model, T any] struct {
ctx context.Context
fn func(V) (T, bool)
q *QueryCondition
}
type testCase[V Model, T any] struct {
name string
args args[V, T]
wantR []T
wantErr bool
}
tests := []testCase[post, uint64]{
{
name: "t1",
args: args[post, uint64]{
ctx: ctx,
fn: func(t post) (uint64, bool) {
return t.Id, true
},
q: Conditions(
Where(SqlBuilder{
{"ID", "<", "200", "int"},
}),
),
},
wantR: []uint64{63, 64, 190, 193},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotR, err := Column[post](tt.args.ctx, tt.args.fn, tt.args.q)
if (err != nil) != tt.wantErr {
t.Errorf("Column() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("Column() gotR = %v, want %v", gotR, tt.wantR)
}
})
}
}
type options struct {
OptionId uint64 `gorm:"column:option_id" db:"option_id" json:"option_id" form:"option_id"`
OptionName string `gorm:"column:option_name" db:"option_name" json:"option_name" form:"option_name"`
OptionValue string `gorm:"column:option_value" db:"option_value" json:"option_value" form:"option_value"`
Autoload string `gorm:"column:autoload" db:"autoload" json:"autoload" form:"autoload"`
}
func (w options) PrimaryKey() string {
return "option_id"
}
func (w options) Table() string {
return "wp_options"
}
func Test_getField(t *testing.T) {
{
name := "string"
db := glob
field := "option_value"
q := Conditions(Where(SqlBuilder{{"option_name", "blogname"}}))
wantR := "记录并见证自己的成长"
wantErr := false
t.Run(name, func(t *testing.T) {
gotR, err := getField[options](db, ctx, field, q)
if (err != nil) != wantErr {
t.Errorf("getField() error = %v, wantErr %v", err, wantErr)
return
}
if !reflect.DeepEqual(gotR, wantR) {
t.Errorf("getField() gotR = %v, want %v", gotR, wantR)
}
})
}
{
name := "t2"
db := glob
field := "option_id"
q := Conditions(Where(SqlBuilder{{"option_name", "blogname"}}))
wantR := "3"
wantErr := false
t.Run(name, func(t *testing.T) {
gotR, err := getField[options](db, ctx, field, q)
if (err != nil) != wantErr {
t.Errorf("getField() error = %v, wantErr %v", err, wantErr)
return
}
if !reflect.DeepEqual(gotR, wantR) {
t.Errorf("getField() gotR = %v, want %v", gotR, wantR)
}
})
}
{
name := "count(*)"
db := glob
field := "count(*)"
q := Conditions()
wantR := "386"
wantErr := false
t.Run(name, func(t *testing.T) {
gotR, err := getField[options](db, ctx, field, q)
if (err != nil) != wantErr {
t.Errorf("getField() error = %v, wantErr %v", err, wantErr)
return
}
if !reflect.DeepEqual(gotR, wantR) {
t.Errorf("getField() gotR = %v, want %v", gotR, wantR)
}
})
}
}
func Test_getToStringMap(t *testing.T) {
type args struct {
db dbQuery
ctx context.Context
q *QueryCondition
}
tests := []struct {
name string
args args
wantR map[string]string
wantErr bool
}{
{
name: "t1",
args: args{
db: glob,
ctx: ctx,
q: Conditions(Where(SqlBuilder{{"option_name", "users_can_register"}})),
},
wantR: map[string]string{
"option_id": "5",
"option_value": "0",
"option_name": "users_can_register",
"autoload": "yes",
},
},
{
name: "t2",
args: args{
db: glob,
ctx: ctx,
q: Conditions(
Where(SqlBuilder{{"option_name", "users_can_register"}}),
Fields("option_id id"),
),
},
wantR: map[string]string{
"id": "5",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotR, err := getToStringMap[options](tt.args.db, tt.args.ctx, tt.args.q)
if (err != nil) != tt.wantErr {
t.Errorf("getToStringMap() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("getToStringMap() gotR = %v, want %v", gotR, tt.wantR)
}
})
}
}
func Test_findToStringMap(t *testing.T) {
type args struct {
db dbQuery
ctx context.Context
q *QueryCondition
}
tests := []struct {
name string
args args
wantR []map[string]string
wantErr bool
}{
{
name: "t1",
args: args{
db: glob,
ctx: ctx,
q: Conditions(Where(SqlBuilder{{"option_id", "5"}})),
},
wantR: []map[string]string{{
"option_id": "5",
"option_value": "0",
"option_name": "users_can_register",
"autoload": "yes",
}},
wantErr: false,
},
{
name: "t2",
args: args{
db: glob,
ctx: ctx,
q: Conditions(
Where(SqlBuilder{{"option_id", "5"}}),
Fields("option_value,option_name"),
),
},
wantR: []map[string]string{{
"option_value": "0",
"option_name": "users_can_register",
}},
wantErr: false,
},
{
name: "t3",
args: args{
db: glob,
ctx: ctx,
q: Conditions(
Where(SqlBuilder{{"option_id", "5"}}),
Fields("option_value v,option_name k"),
),
},
wantR: []map[string]string{{
"v": "0",
"k": "users_can_register",
}},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotR, err := findToStringMap[options](tt.args.db, tt.args.ctx, tt.args.q)
if (err != nil) != tt.wantErr {
t.Errorf("findToStringMap() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("findToStringMap() gotR = %v, want %v", gotR, tt.wantR)
}
})
}
}
func Test_findScanner(t *testing.T) {
type args[T Model] struct {
db dbQuery
ctx context.Context
fn func(T)
q *QueryCondition
}
type testCase[T Model] struct {
name string
args args[T]
wantErr bool
}
tests := []testCase[options]{
{
name: "t1",
args: args[options]{glob, ctx, func(t options) {
fmt.Println(t)
}, Conditions(Where(SqlBuilder{{"option_id", "<", "10", "int"}}))},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := findScanner[options](tt.args.db, tt.args.ctx, tt.args.fn, tt.args.q); (err != nil) != tt.wantErr {
t.Errorf("findScanner() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func BenchmarkScannerXX(b *testing.B) {
for i := 0; i < b.N; i++ {
err := findScanner[options](glob, ctx, func(t options) {
_ = t
//fmt.Println(t)
}, Conditions(Where(SqlBuilder{{"option_id", "<", "100", "int"}})))
if err != nil {
panic(err)
}
}
}
func BenchmarkFindsXX(b *testing.B) {
for i := 0; i < b.N; i++ {
r, err := finds[options](glob, ctx, Conditions(Where(SqlBuilder{{"option_id", "<", "100", "int"}})))
if err != nil {
panic(err)
}
for _, o := range r {
_ = o
//fmt.Println(o)
}
}
}
func Test_gets(t *testing.T) {
type args struct {
db dbQuery
ctx context.Context
q *QueryCondition
}
type testCase[T Model] struct {
name string
args args
wantR T
wantErr bool
}
tests := []testCase[options]{
{
name: "t1",
args: args{
db: glob,
ctx: ctx,
q: Conditions(Where(SqlBuilder{{"option_name", "blogname"}})),
},
wantR: options{3, "blogname", "记录并见证自己的成长", "yes"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotR, err := gets[options](tt.args.db, tt.args.ctx, tt.args.q)
if (err != nil) != tt.wantErr {
t.Errorf("gets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("gets() gotR = %v, want %v", gotR, tt.wantR)
}
})
}
}

View File

@ -9,14 +9,14 @@ import (
"strings" "strings"
) )
type SqlxQuery struct { type SqlxQuery[T any] struct {
sqlx *sqlx.DB sqlx *sqlx.DB
UniversalDb UniversalDb[T]
} }
func NewSqlxQuery(sqlx *sqlx.DB, u UniversalDb) *SqlxQuery { func NewSqlxQuery[T any](sqlx *sqlx.DB, u UniversalDb[T]) *SqlxQuery[T] {
s := &SqlxQuery{sqlx: sqlx, UniversalDb: u} s := &SqlxQuery[T]{sqlx: sqlx, UniversalDb: u}
if u.selects == nil { if u.selects == nil {
s.UniversalDb.selects = s.Selects s.UniversalDb.selects = s.Selects
} }
@ -26,52 +26,56 @@ func NewSqlxQuery(sqlx *sqlx.DB, u UniversalDb) *SqlxQuery {
return s return s
} }
func SetSelect(db *SqlxQuery, fn func(context.Context, any, string, ...any) error) { func SetSelect[T any](db *SqlxQuery[T], fn QuerySelect[T]) {
db.selects = fn db.selects = fn
} }
func SetGet(db *SqlxQuery, fn func(context.Context, any, string, ...any) error) { func SetGet[T any](db *SqlxQuery[T], fn QueryGet[T]) {
db.gets = fn db.gets = fn
} }
func (r *SqlxQuery) Selects(ctx context.Context, dest any, sql string, params ...any) error { func (s *SqlxQuery[T]) Selects(ctx context.Context, sql string, params ...any) (r []T, err error) {
v := ctx.Value("handle=>") v := ctx.Value("handle=>")
if v != nil { if v != nil {
vv, ok := v.(string) vv, ok := v.(string)
if ok && vv != "" { if ok && vv != "" {
switch vv { switch vv {
case "string": case "string":
return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...) //return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...)
case "scanner": case "scanner":
fn := ctx.Value("fn") fn := ctx.Value("fn")
return Scanner[any](r.sqlx, dest, sql, params...)(fn.(func(any))) return nil, Scanner[T](s.sqlx, sql, params...)(fn.(func(T)))
} }
} }
} }
return r.sqlx.Select(dest, sql, params...) //var a T
err = s.sqlx.Select(&r, sql, params...)
return
} }
func (r *SqlxQuery) Gets(ctx context.Context, dest any, sql string, params ...any) error { func (s *SqlxQuery[T]) Gets(ctx context.Context, sql string, params ...any) (r T, err error) {
v := ctx.Value("handle=>") v := ctx.Value("handle=>")
if v != nil { if v != nil {
vv, ok := v.(string) vv, ok := v.(string)
if ok && vv != "" { if ok && vv != "" {
switch vv { switch vv {
case "string": case "string":
return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...) //return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...)
} }
} }
} }
return r.sqlx.Get(dest, sql, params...) err = s.sqlx.Get(&r, sql, params...)
return
} }
func Scanner[T any](db *sqlx.DB, v T, s string, params ...any) func(func(T)) error { func Scanner[T any](db *sqlx.DB, s string, params ...any) func(func(T)) error {
var v T
return func(fn func(T)) error { return func(fn func(T)) error {
rows, err := db.Queryx(s, params...) rows, err := db.Queryx(s, params...)
if err != nil { if err != nil {
return err return err
} }
for rows.Next() { for rows.Next() {
err = rows.StructScan(v) err = rows.StructScan(&v)
if err != nil { if err != nil {
return err return err
} }

View File

@ -2,19 +2,15 @@ package model
import "context" import "context"
type UniversalDb struct { type UniversalDb[T any] struct {
selects QueryFn selects QuerySelect[T]
gets QueryFn gets QueryGet[T]
} }
func NewUniversalDb(selects QueryFn, gets QueryFn) UniversalDb { func (u *UniversalDb[T]) Select(ctx context.Context, s string, a ...any) ([]T, error) {
return UniversalDb{selects: selects, gets: gets} return u.selects(ctx, s, a...)
} }
func (u UniversalDb) Select(ctx context.Context, a any, s string, args ...any) error { func (u *UniversalDb[T]) Get(ctx context.Context, s string, a ...any) (T, error) {
return u.selects(ctx, a, s, args...) return u.gets(ctx, s, a...)
}
func (u UniversalDb) Get(ctx context.Context, a any, s string, args ...any) error {
return u.gets(ctx, a, s, args...)
} }