Compare commits
3 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
9ee402af4b | ||
|
00f788fad5 | ||
|
1ecfa19fd4 |
|
@ -2,6 +2,7 @@ package model
|
|||
|
||||
type QueryCondition struct {
|
||||
where ParseWhere
|
||||
from string
|
||||
fields string
|
||||
group string
|
||||
order SqlBuilder
|
||||
|
@ -37,6 +38,12 @@ func Fields(fields string) Condition {
|
|||
}
|
||||
}
|
||||
|
||||
func From(from string) Condition {
|
||||
return func(c *QueryCondition) {
|
||||
c.from = from
|
||||
}
|
||||
}
|
||||
|
||||
func Group(group string) Condition {
|
||||
return func(c *QueryCondition) {
|
||||
c.group = group
|
||||
|
|
|
@ -5,13 +5,14 @@ import (
|
|||
)
|
||||
|
||||
var _ ParseWhere = SqlBuilder{}
|
||||
var globalBb dbQuery
|
||||
var globalBb dbQuery[Model]
|
||||
|
||||
func InitDB(db dbQuery) {
|
||||
func InitDB(db dbQuery[Model]) {
|
||||
globalBb = db
|
||||
}
|
||||
|
||||
type QueryFn func(context.Context, any, string, ...any) error
|
||||
type QuerySelect[T any] func(context.Context, string, ...any) ([]T, error)
|
||||
type QueryGet[T any] func(context.Context, string, ...any) (T, error)
|
||||
|
||||
type Model interface {
|
||||
PrimaryKey() string
|
||||
|
@ -22,9 +23,9 @@ type ParseWhere interface {
|
|||
ParseWhere(*[][]any) (string, []any, error)
|
||||
}
|
||||
|
||||
type dbQuery interface {
|
||||
Select(context.Context, any, string, ...any) error
|
||||
Get(context.Context, any, string, ...any) error
|
||||
type dbQuery[T any] interface {
|
||||
Select(context.Context, string, ...any) ([]T, error)
|
||||
Get(context.Context, string, ...any) (T, error)
|
||||
}
|
||||
|
||||
type SqlBuilder [][]string
|
||||
|
|
|
@ -1,40 +1,28 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/fthvgb1/wp-go/helper/slice"
|
||||
str "github.com/fthvgb1/wp-go/helper/strings"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (w SqlBuilder) parseField(ss []string, s *strings.Builder) {
|
||||
func (w SqlBuilder) parseField(ss []string, s *str.Builder) {
|
||||
if strings.Contains(ss[0], ".") && !strings.Contains(ss[0], "(") {
|
||||
s.WriteString("`")
|
||||
sx := strings.Split(ss[0], ".")
|
||||
s.WriteString(sx[0])
|
||||
s.WriteString("`.`")
|
||||
s.WriteString(sx[1])
|
||||
s.WriteString("`")
|
||||
s.Sprintf("`%s`.`%s`", sx[0], sx[1])
|
||||
} else if !strings.Contains(ss[0], ".") && !strings.Contains(ss[0], "(") {
|
||||
s.WriteString("`")
|
||||
s.WriteString(ss[0])
|
||||
s.WriteString("`")
|
||||
s.Sprintf("`%s`", ss[0])
|
||||
} else {
|
||||
s.WriteString(ss[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (w SqlBuilder) parseIn(ss []string, s *strings.Builder, c *int, args *[]any, in *[][]any) (t bool) {
|
||||
func (w SqlBuilder) parseIn(ss []string, s *str.Builder, c *int, args *[]any, in *[][]any) (t bool) {
|
||||
if slice.IsContained(ss[1], []string{"in", "not in"}) && len(*in) > 0 {
|
||||
s.WriteString(" (")
|
||||
for _, p := range (*in)[*c] {
|
||||
s.WriteString("?,")
|
||||
*args = append(*args, p)
|
||||
}
|
||||
sx := s.String()
|
||||
s.Reset()
|
||||
s.WriteString(strings.TrimRight(sx, ","))
|
||||
s.WriteString(")")
|
||||
sss := strings.Repeat("?,", len((*in)[*c]))
|
||||
s.Sprintf("(%s)", strings.TrimRight(sss, ","))
|
||||
*args = append(*args, (*in)[*c]...)
|
||||
*c++
|
||||
t = true
|
||||
}
|
||||
|
@ -78,18 +66,18 @@ func (w SqlBuilder) parseType(ss []string, args *[]any) error {
|
|||
//
|
||||
// {{"and","field","=","num1","int","or","field","=","num2","int"}} => where (field = num1 or field = num2')
|
||||
func (w SqlBuilder) ParseWhere(in *[][]any) (string, []any, error) {
|
||||
var s strings.Builder
|
||||
var s = str.NewBuilder()
|
||||
args := make([]any, 0, len(w))
|
||||
c := 0
|
||||
for _, ss := range w {
|
||||
if len(ss) == 2 {
|
||||
w.parseField(ss, &s)
|
||||
w.parseField(ss, s)
|
||||
s.WriteString("=? and ")
|
||||
args = append(args, ss[1])
|
||||
} else if len(ss) >= 3 && len(ss) < 5 {
|
||||
w.parseField(ss, &s)
|
||||
w.parseField(ss, s)
|
||||
s.WriteString(ss[1])
|
||||
if w.parseIn(ss, &s, &c, &args, in) {
|
||||
if w.parseIn(ss, s, &c, &args, in) {
|
||||
s.WriteString(" and ")
|
||||
continue
|
||||
}
|
||||
|
@ -107,15 +95,14 @@ func (w SqlBuilder) ParseWhere(in *[][]any) (string, []any, error) {
|
|||
if strings.Contains(st, "and ") && ss[start] == "or" {
|
||||
st = strings.TrimRight(st, "and ")
|
||||
s.Reset()
|
||||
s.WriteString(st)
|
||||
s.WriteString(fmt.Sprintf(" %s ", ss[start]))
|
||||
s.Sprintf("%s %s ", st, ss[start])
|
||||
}
|
||||
if i == 0 {
|
||||
s.WriteString("( ")
|
||||
}
|
||||
w.parseField(ss[start+1:end], &s)
|
||||
w.parseField(ss[start+1:end], s)
|
||||
s.WriteString(ss[start+2])
|
||||
if w.parseIn(ss[start+1:end], &s, &c, &args, in) {
|
||||
if w.parseIn(ss[start+1:end], s, &c, &args, in) {
|
||||
s.WriteString(" and ")
|
||||
continue
|
||||
}
|
||||
|
@ -128,15 +115,13 @@ func (w SqlBuilder) ParseWhere(in *[][]any) (string, []any, error) {
|
|||
st := s.String()
|
||||
st = strings.TrimRight(st, "and ")
|
||||
s.Reset()
|
||||
s.WriteString(st)
|
||||
s.WriteString(") and ")
|
||||
s.Sprintf("%s ) and ", st)
|
||||
}
|
||||
}
|
||||
ss := strings.TrimRight(s.String(), "and ")
|
||||
if ss != "" {
|
||||
s.Reset()
|
||||
s.WriteString(" where ")
|
||||
s.WriteString(ss)
|
||||
s.Sprintf(" where %s", ss)
|
||||
ss = s.String()
|
||||
}
|
||||
if len(*in) > c {
|
||||
|
|
186
model/query.go
186
model/query.go
|
@ -1,185 +1,25 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"golang.org/x/exp/constraints"
|
||||
"math/rand"
|
||||
"strings"
|
||||
)
|
||||
import "context"
|
||||
|
||||
func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) {
|
||||
var rr T
|
||||
var w string
|
||||
var args []any
|
||||
if where != nil {
|
||||
w, args, err = where.ParseWhere(&in)
|
||||
func finds[T Model](db dbQuery[T], ctx context.Context, q *QueryCondition) ([]T, error) {
|
||||
s, args, err := BuildQuerySql[T](q)
|
||||
if err != nil {
|
||||
return r, total, err
|
||||
return nil, err
|
||||
}
|
||||
return db.Select(ctx, s, args...)
|
||||
}
|
||||
|
||||
h := ""
|
||||
if having != nil {
|
||||
hh, arg, err := having.ParseWhere(&in)
|
||||
func scanners[T Model](db dbQuery[T], ctx context.Context, q *QueryCondition) ([]T, error) {
|
||||
s, args, err := BuildQuerySql[T](q)
|
||||
if err != nil {
|
||||
return r, total, err
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, arg...)
|
||||
h = strings.Replace(hh, " where", " having", 1)
|
||||
}
|
||||
|
||||
n := struct {
|
||||
N int `db:"n" json:"n"`
|
||||
}{}
|
||||
groupBy := ""
|
||||
if group != "" {
|
||||
g := strings.Builder{}
|
||||
g.WriteString(" group by ")
|
||||
g.WriteString(group)
|
||||
groupBy = g.String()
|
||||
}
|
||||
if having != nil {
|
||||
tm := map[string]struct{}{}
|
||||
for _, s := range strings.Split(group, ",") {
|
||||
tm[s] = struct{}{}
|
||||
}
|
||||
for _, ss := range having {
|
||||
if _, ok := tm[ss[0]]; !ok {
|
||||
group = fmt.Sprintf("%s,%s", group, ss[0])
|
||||
}
|
||||
}
|
||||
group = strings.Trim(group, ",")
|
||||
}
|
||||
j := join.parseJoin()
|
||||
if group == "" {
|
||||
tpx := "select count(*) n from %s %s %s limit 1"
|
||||
sq := fmt.Sprintf(tpx, rr.Table(), j, w)
|
||||
err = db.Get(ctx, &n, sq, args...)
|
||||
} else {
|
||||
tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s"
|
||||
sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int()))
|
||||
err = db.Get(ctx, &n, sq, args...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n.N == 0 {
|
||||
return
|
||||
}
|
||||
total = n.N
|
||||
offset := 0
|
||||
if page > 1 {
|
||||
offset = (page - 1) * pageSize
|
||||
}
|
||||
if offset >= total {
|
||||
return
|
||||
}
|
||||
tp := "select %s from %s %s %s %s %s %s limit %d,%d"
|
||||
sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize)
|
||||
err = db.Select(ctx, &r, sq, args...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) {
|
||||
r, total, err = pagination[T](globalBb, ctx, where, fields, group, page, pageSize, order, join, having, in...)
|
||||
return
|
||||
}
|
||||
|
||||
func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) {
|
||||
var r T
|
||||
sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey())
|
||||
err := globalBb.Get(ctx, &r, sq, id)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (r T, err error) {
|
||||
s, args, err := BuildQuerySql[T](&QueryCondition{
|
||||
where: where,
|
||||
fields: fields,
|
||||
order: order,
|
||||
in: in,
|
||||
limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = globalBb.Get(ctx, &r, s, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (T, error) {
|
||||
var r T
|
||||
var w string
|
||||
var args []any
|
||||
var err error
|
||||
if where != nil {
|
||||
w, args, err = where.ParseWhere(&in)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
}
|
||||
tp := "select %s from %s %s order by %s desc limit 1"
|
||||
sq := fmt.Sprintf(tp, fields, r.Table(), w, r.PrimaryKey())
|
||||
err = globalBb.Get(ctx, &r, sq, args...)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) {
|
||||
s, args, err := BuildQuerySql[T](&QueryCondition{
|
||||
where: where,
|
||||
fields: fields,
|
||||
in: in,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = globalBb.Select(ctx, &r, s, args...)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func Select[T Model](ctx context.Context, sql string, params ...any) ([]T, error) {
|
||||
ctx = context.WithValue(ctx, "handle=>", "scanner")
|
||||
var r []T
|
||||
var rr T
|
||||
sql = strings.Replace(sql, "{table}", rr.Table(), -1)
|
||||
err := globalBb.Select(ctx, &r, sql, params...)
|
||||
if err != nil {
|
||||
ctx = context.WithValue(ctx, "fn", func(t T) {
|
||||
r = append(r, t)
|
||||
})
|
||||
_, err = db.Select(ctx, s, args...)
|
||||
return r, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) {
|
||||
q := QueryCondition{
|
||||
where: where,
|
||||
fields: fields,
|
||||
group: group,
|
||||
order: order,
|
||||
join: join,
|
||||
having: having,
|
||||
limit: limit,
|
||||
in: in,
|
||||
}
|
||||
s, args, err := BuildQuerySql[T](&q)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = globalBb.Select(ctx, &r, s, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func Get[T Model](ctx context.Context, sql string, params ...any) (r T, err error) {
|
||||
sql = strings.Replace(sql, "{table}", r.Table(), -1)
|
||||
err = globalBb.Get(ctx, &r, sql, params...)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -2,517 +2,159 @@ package model
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"github.com/fthvgb1/wp-go/helper/number"
|
||||
"github.com/fthvgb1/wp-go/helper/slice"
|
||||
"github.com/fthvgb1/wp-go/safety"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type post struct {
|
||||
Id uint64 `gorm:"column:ID" db:"ID" json:"ID" form:"ID"`
|
||||
PostAuthor uint64 `gorm:"column:post_author" db:"post_author" json:"post_author" form:"post_author"`
|
||||
PostDate time.Time `gorm:"column:post_date" db:"post_date" json:"post_date" form:"post_date"`
|
||||
PostDateGmt time.Time `gorm:"column:post_date_gmt" db:"post_date_gmt" json:"post_date_gmt" form:"post_date_gmt"`
|
||||
PostContent string `gorm:"column:post_content" db:"post_content" json:"post_content" form:"post_content"`
|
||||
PostTitle string `gorm:"column:post_title" db:"post_title" json:"post_title" form:"post_title"`
|
||||
PostExcerpt string `gorm:"column:post_excerpt" db:"post_excerpt" json:"post_excerpt" form:"post_excerpt"`
|
||||
PostStatus string `gorm:"column:post_status" db:"post_status" json:"post_status" form:"post_status"`
|
||||
CommentStatus string `gorm:"column:comment_status" db:"comment_status" json:"comment_status" form:"comment_status"`
|
||||
PingStatus string `gorm:"column:ping_status" db:"ping_status" json:"ping_status" form:"ping_status"`
|
||||
PostPassword string `gorm:"column:post_password" db:"post_password" json:"post_password" form:"post_password"`
|
||||
PostName string `gorm:"column:post_name" db:"post_name" json:"post_name" form:"post_name"`
|
||||
ToPing string `gorm:"column:to_ping" db:"to_ping" json:"to_ping" form:"to_ping"`
|
||||
Pinged string `gorm:"column:pinged" db:"pinged" json:"pinged" form:"pinged"`
|
||||
PostModified time.Time `gorm:"column:post_modified" db:"post_modified" json:"post_modified" form:"post_modified"`
|
||||
PostModifiedGmt time.Time `gorm:"column:post_modified_gmt" db:"post_modified_gmt" json:"post_modified_gmt" form:"post_modified_gmt"`
|
||||
PostContentFiltered string `gorm:"column:post_content_filtered" db:"post_content_filtered" json:"post_content_filtered" form:"post_content_filtered"`
|
||||
PostParent uint64 `gorm:"column:post_parent" db:"post_parent" json:"post_parent" form:"post_parent"`
|
||||
Guid string `gorm:"column:guid" db:"guid" json:"guid" form:"guid"`
|
||||
MenuOrder int `gorm:"column:menu_order" db:"menu_order" json:"menu_order" form:"menu_order"`
|
||||
PostType string `gorm:"column:post_type" db:"post_type" json:"post_type" form:"post_type"`
|
||||
PostMimeType string `gorm:"column:post_mime_type" db:"post_mime_type" json:"post_mime_type" form:"post_mime_type"`
|
||||
CommentCount int64 `gorm:"column:comment_count" db:"comment_count" json:"comment_count" form:"comment_count"`
|
||||
}
|
||||
|
||||
type user struct {
|
||||
Id uint64 `gorm:"column:ID" db:"ID" json:"ID"`
|
||||
UserLogin string `gorm:"column:user_login" db:"user_login" json:"user_login"`
|
||||
UserPass string `gorm:"column:user_pass" db:"user_pass" json:"user_pass"`
|
||||
UserNicename string `gorm:"column:user_nicename" db:"user_nicename" json:"user_nicename"`
|
||||
UserEmail string `gorm:"column:user_email" db:"user_email" json:"user_email"`
|
||||
UserUrl string `gorm:"column:user_url" db:"user_url" json:"user_url"`
|
||||
UserRegistered time.Time `gorm:"column:user_registered" db:"user_registered" json:"user_registered"`
|
||||
UserActivationKey string `gorm:"column:user_activation_key" db:"user_activation_key" json:"user_activation_key"`
|
||||
UserStatus int `gorm:"column:user_status" db:"user_status" json:"user_status"`
|
||||
DisplayName string `gorm:"column:display_name" db:"display_name" json:"display_name"`
|
||||
}
|
||||
|
||||
type termTaxonomy struct {
|
||||
TermTaxonomyId uint64 `gorm:"column:term_taxonomy_id" db:"term_taxonomy_id" json:"term_taxonomy_id" form:"term_taxonomy_id"`
|
||||
TermId uint64 `gorm:"column:term_id" db:"term_id" json:"term_id" form:"term_id"`
|
||||
Taxonomy string `gorm:"column:taxonomy" db:"taxonomy" json:"taxonomy" form:"taxonomy"`
|
||||
Description string `gorm:"column:description" db:"description" json:"description" form:"description"`
|
||||
Parent uint64 `gorm:"column:parent" db:"parent" json:"parent" form:"parent"`
|
||||
Count int64 `gorm:"column:count" db:"count" json:"count" form:"count"`
|
||||
}
|
||||
|
||||
type terms struct {
|
||||
TermId uint64 `gorm:"column:term_id" db:"term_id" json:"term_id" form:"term_id"`
|
||||
Name string `gorm:"column:name" db:"name" json:"name" form:"name"`
|
||||
Slug string `gorm:"column:slug" db:"slug" json:"slug" form:"slug"`
|
||||
TermGroup int64 `gorm:"column:term_group" db:"term_group" json:"term_group" form:"term_group"`
|
||||
}
|
||||
|
||||
func (t terms) PrimaryKey() string {
|
||||
return "term_id"
|
||||
}
|
||||
func (t terms) Table() string {
|
||||
return "wp_terms"
|
||||
}
|
||||
|
||||
func (w termTaxonomy) PrimaryKey() string {
|
||||
return "term_taxonomy_id"
|
||||
}
|
||||
|
||||
func (w termTaxonomy) Table() string {
|
||||
return "wp_term_taxonomy"
|
||||
}
|
||||
|
||||
func (u user) Table() string {
|
||||
return "wp_users"
|
||||
}
|
||||
|
||||
func (u user) PrimaryKey() string {
|
||||
return "ID"
|
||||
}
|
||||
|
||||
func (p post) PrimaryKey() string {
|
||||
return "ID"
|
||||
}
|
||||
|
||||
func (p post) Table() string {
|
||||
return "wp_posts"
|
||||
}
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
var glob *SqlxQuery
|
||||
var glob = safety.NewMap[string, dbQuery[Model]]()
|
||||
var dbMap = sync.Map{}
|
||||
|
||||
var sq *sqlx.DB
|
||||
|
||||
func anyDb[T Model]() *SqlxQuery[T] {
|
||||
var a T
|
||||
db, ok := dbMap.Load(a.Table())
|
||||
if ok {
|
||||
return db.(*SqlxQuery[T])
|
||||
}
|
||||
dbb := NewSqlxQuery[T](sq, UniversalDb[T]{nil, nil})
|
||||
dbMap.Store(a.Table(), dbb)
|
||||
return dbb
|
||||
}
|
||||
|
||||
func init() {
|
||||
db, err := sqlx.Open("mysql", "root:root@tcp(192.168.66.47:3306)/wordpress?charset=utf8mb4&parseTime=True&loc=Local")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
glob = NewSqlxQuery(db, NewUniversalDb(nil, nil))
|
||||
InitDB(glob)
|
||||
}
|
||||
func TestFind(t *testing.T) {
|
||||
type args struct {
|
||||
where ParseWhere
|
||||
fields string
|
||||
group string
|
||||
order SqlBuilder
|
||||
join SqlBuilder
|
||||
having SqlBuilder
|
||||
limit int
|
||||
in [][]any
|
||||
}
|
||||
type posts struct {
|
||||
post
|
||||
N int `db:"n"`
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantR []posts
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "in,orderBy",
|
||||
args: args{
|
||||
where: SqlBuilder{{
|
||||
"post_status", "publish",
|
||||
}, {"ID", "in", ""}},
|
||||
fields: "*",
|
||||
group: "",
|
||||
order: SqlBuilder{{"ID", "desc"}},
|
||||
join: nil,
|
||||
having: nil,
|
||||
limit: 0,
|
||||
in: [][]any{{1, 2, 3, 4}},
|
||||
},
|
||||
wantR: func() []posts {
|
||||
r, err := Select[posts](ctx, "select * from "+posts{}.Table()+" where post_status='publish' and ID in (1,2,3,4) order by ID desc")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "or",
|
||||
args: args{
|
||||
where: SqlBuilder{{
|
||||
"and", "ID", "=", "1", "int",
|
||||
}, {"or", "ID", "=", "2", "int"}},
|
||||
fields: "*",
|
||||
group: "",
|
||||
order: nil,
|
||||
join: nil,
|
||||
having: nil,
|
||||
limit: 0,
|
||||
in: nil,
|
||||
},
|
||||
wantR: func() []posts {
|
||||
r, err := Select[posts](ctx, "select * from "+posts{}.Table()+" where (ID=1 or ID=2)")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "group,having",
|
||||
args: args{
|
||||
where: SqlBuilder{
|
||||
{"ID", "<", "1000", "int"},
|
||||
},
|
||||
fields: "post_status,count(*) n",
|
||||
group: "post_status",
|
||||
order: nil,
|
||||
join: nil,
|
||||
having: SqlBuilder{
|
||||
{"n", ">", "1"},
|
||||
},
|
||||
limit: 0,
|
||||
in: nil,
|
||||
},
|
||||
wantR: func() []posts {
|
||||
r, err := Select[posts](ctx, "select post_status,count(*) n from "+post{}.Table()+" where ID<1000 group by post_status having n>1")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "or、多个in",
|
||||
args: args{
|
||||
where: SqlBuilder{
|
||||
{"and", "ID", "in", "", "", "or", "ID", "in", "", ""},
|
||||
{"or", "post_status", "=", "publish", "", "and", "post_status", "=", "closed", ""},
|
||||
},
|
||||
fields: "*",
|
||||
group: "",
|
||||
order: nil,
|
||||
join: nil,
|
||||
having: nil,
|
||||
limit: 0,
|
||||
in: [][]any{{1, 2, 3}, {4, 5, 6}},
|
||||
},
|
||||
wantR: func() []posts {
|
||||
r, err := Select[posts](ctx, "select * from "+posts{}.Table()+" where (ID in (1,2,3) or ID in (4,5,6)) or (post_status='publish' and post_status='closed')")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "all",
|
||||
args: args{
|
||||
where: SqlBuilder{
|
||||
{"b.user_login", "in", ""},
|
||||
{"and", "a.post_type", "=", "post", "", "or", "a.post_type", "=", "page", ""},
|
||||
{"a.comment_count", ">", "0", "int"},
|
||||
{"a.post_status", "publish"},
|
||||
{"e.name", "in", ""},
|
||||
{"d.taxonomy", "category"},
|
||||
},
|
||||
fields: "post_author,count(*) n",
|
||||
group: "a.post_author",
|
||||
order: SqlBuilder{{"n", "desc"}},
|
||||
join: SqlBuilder{
|
||||
{"a", "left join", user{}.Table() + " b", "a.post_author=b.ID"},
|
||||
{"left join", "wp_term_relationships c", "a.Id=c.object_id"},
|
||||
{"left join", termTaxonomy{}.Table() + " d", "c.term_taxonomy_id=d.term_taxonomy_id"},
|
||||
{"left join", terms{}.Table() + " e", "d.term_id=e.term_id"},
|
||||
},
|
||||
having: SqlBuilder{{"n", ">", "0", "int"}},
|
||||
limit: 10,
|
||||
in: [][]any{{"test", "test2"}, {"web", "golang", "php"}},
|
||||
},
|
||||
wantR: func() []posts {
|
||||
r, err := Select[posts](ctx, "select post_author,count(*) n from wp_posts a left join wp_users b on a.post_author=b.ID left join wp_term_relationships c on a.Id=c.object_id left join wp_term_taxonomy d on c.term_taxonomy_id=d.term_taxonomy_id left join wp_terms e on d.term_id=e.term_id where b.user_login in ('test','test2') and b.user_status=0 and (a.post_type='post' or a.post_type='page') and a.comment_count>0 and a.post_status='publish' and e.name in ('web','golang','php') and d.taxonomy='category' group by post_author having n > 0 order by n desc limit 10")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotR, err := Find[posts](ctx, tt.args.where, tt.args.fields, tt.args.group, tt.args.order, tt.args.join, tt.args.having, tt.args.limit, tt.args.in...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("Find() gotR = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
sq = db
|
||||
//glob = NewSqlxQuery(db, NewUniversalDb(nil, nil))
|
||||
|
||||
}
|
||||
|
||||
func TestFindOneById(t *testing.T) {
|
||||
type args struct {
|
||||
id int
|
||||
func Test_selects(t *testing.T) {
|
||||
type args[T Model] struct {
|
||||
db dbQuery[T]
|
||||
ctx context.Context
|
||||
q *QueryCondition
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
type testCase[T Model] struct {
|
||||
name string
|
||||
args args
|
||||
want post
|
||||
args args[T]
|
||||
want []T
|
||||
wantErr bool
|
||||
}{
|
||||
}
|
||||
tests := []testCase[options]{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
1,
|
||||
args: args[options]{
|
||||
anyDb[options](),
|
||||
ctx,
|
||||
Conditions(Where(SqlBuilder{{"option_name", "blogname"}})),
|
||||
},
|
||||
want: func() post {
|
||||
r, err := Get[post](ctx, "select * from "+post{}.Table()+" where ID=?", 1)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
panic(err)
|
||||
} else if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := FindOneById[post](ctx, tt.args.id)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
got, err := finds[options](tt.args.db, tt.args.ctx, tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FindOneById() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("finds() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("FindOneById() got = %v, want %v", got, tt.want)
|
||||
t.Errorf("finds() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstOne(t *testing.T) {
|
||||
type args struct {
|
||||
where ParseWhere
|
||||
fields string
|
||||
order SqlBuilder
|
||||
in [][]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want post
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
where: SqlBuilder{{"post_status", "publish"}},
|
||||
fields: "*",
|
||||
order: SqlBuilder{{"ID", "desc"}},
|
||||
in: nil,
|
||||
},
|
||||
wantErr: false,
|
||||
want: func() post {
|
||||
r, err := Get[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' order by ID desc limit 1")
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
panic(err)
|
||||
} else if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := FirstOne[post](ctx, tt.args.where, tt.args.fields, tt.args.order, tt.args.in...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FirstOne() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("FirstOne() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLastOne(t *testing.T) {
|
||||
type args struct {
|
||||
where ParseWhere
|
||||
fields string
|
||||
in [][]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want post
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
where: SqlBuilder{{
|
||||
"post_status", "publish",
|
||||
}},
|
||||
fields: "*",
|
||||
in: nil,
|
||||
},
|
||||
want: func() post {
|
||||
r, err := Get[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' order by "+post{}.PrimaryKey()+" desc limit 1")
|
||||
func BenchmarkSelectXX(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := finds[options](anyDb[options](), ctx, Conditions(
|
||||
Where(SqlBuilder{{"option_id", "<", "50", "int"}}),
|
||||
//In(slice.ToAnySlice(number.Range[uint64](1, 50, 1))),
|
||||
))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := LastOne[post](ctx, tt.args.where, tt.args.fields, tt.args.in...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("LastOne() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("LastOne() got = %v, want %v", got, tt.want)
|
||||
func BenchmarkScannerXX(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := scanners[options](anyDb[options](), ctx, Conditions(
|
||||
Where(SqlBuilder{{"option_id", "<", "50", "int"}}),
|
||||
//In(slice.ToAnySlice(number.Range[uint64](1, 50, 1))),
|
||||
))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleFind(t *testing.T) {
|
||||
type args struct {
|
||||
where ParseWhere
|
||||
fields string
|
||||
in [][]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []post
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
where: SqlBuilder{
|
||||
{"ID", "in", ""},
|
||||
},
|
||||
fields: "*",
|
||||
in: [][]any{{1, 2}},
|
||||
},
|
||||
want: func() (r []post) {
|
||||
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?)", 1, 2)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
func BenchmarkSqlXQueryXX(b *testing.B) {
|
||||
var r []options
|
||||
/*var r []options
|
||||
x := number.Range[uint64](1, 50, 1)
|
||||
j := strings.TrimRight(strings.Repeat("?,", 50), ",")
|
||||
s := str.NewBuilder()
|
||||
s.Sprintf("select * from wp_options where option_id in (%s)", j)
|
||||
ss := s.String()
|
||||
a := slice.ToAnySlice(x)*/
|
||||
ss := "select * from wp_options where option_id < ?"
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := sq.Select(&r, ss, 50)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}(),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := SimpleFind[post](ctx, tt.args.where, tt.args.fields, tt.args.in...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SimpleFind() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SimpleFind() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimplePagination(t *testing.T) {
|
||||
type args struct {
|
||||
where ParseWhere
|
||||
fields string
|
||||
group string
|
||||
page int
|
||||
pageSize int
|
||||
order SqlBuilder
|
||||
join SqlBuilder
|
||||
having SqlBuilder
|
||||
in [][]any
|
||||
type options struct {
|
||||
OptionId uint64 `gorm:"column:option_id" db:"option_id" json:"option_id" form:"option_id"`
|
||||
OptionName string `gorm:"column:option_name" db:"option_name" json:"option_name" form:"option_name"`
|
||||
OptionValue string `gorm:"column:option_value" db:"option_value" json:"option_value" form:"option_value"`
|
||||
Autoload string `gorm:"column:autoload" db:"autoload" json:"autoload" form:"autoload"`
|
||||
}
|
||||
tests := []struct {
|
||||
|
||||
func (w options) PrimaryKey() string {
|
||||
return "option_id"
|
||||
}
|
||||
|
||||
func (w options) Table() string {
|
||||
return "wp_options"
|
||||
}
|
||||
|
||||
func Test_scanners(t *testing.T) {
|
||||
type args[T Model] struct {
|
||||
db dbQuery[T]
|
||||
ctx context.Context
|
||||
q *QueryCondition
|
||||
}
|
||||
type testCase[T Model] struct {
|
||||
name string
|
||||
args args
|
||||
wantR []post
|
||||
wantTotal int
|
||||
args args[T]
|
||||
wantErr bool
|
||||
}{
|
||||
}
|
||||
tests := []testCase[options]{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
where: SqlBuilder{
|
||||
{"ID", "in", ""},
|
||||
args: args[options]{
|
||||
anyDb[options](),
|
||||
ctx,
|
||||
Conditions(Where(SqlBuilder{{"option_name", "blogname"}})),
|
||||
},
|
||||
fields: "*",
|
||||
group: "",
|
||||
page: 1,
|
||||
pageSize: 5,
|
||||
order: nil,
|
||||
join: nil,
|
||||
having: nil,
|
||||
in: [][]any{slice.ToAnySlice(number.Range(431, 440, 1))},
|
||||
},
|
||||
wantR: func() (r []post) {
|
||||
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?,?,?,?)", slice.ToAnySlice(number.Range(431, 435, 1))...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
panic(err)
|
||||
} else if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}(),
|
||||
wantTotal: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotR, gotTotal, err := SimplePagination[post](ctx, tt.args.where, tt.args.fields, tt.args.group, tt.args.page, tt.args.pageSize, tt.args.order, tt.args.join, tt.args.having, tt.args.in...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SimplePagination() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("SimplePagination() gotR = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
if gotTotal != tt.wantTotal {
|
||||
t.Errorf("SimplePagination() gotTotal = %v, want %v", gotTotal, tt.wantTotal)
|
||||
if _, err := scanners[options](tt.args.db, tt.args.ctx, tt.args.q); (err != nil) != tt.wantErr {
|
||||
t.Errorf("scanners() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,222 +1,10 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/fthvgb1/wp-go/helper/slice"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Finds 比 Find 多一个offset
|
||||
//
|
||||
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
|
||||
func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) {
|
||||
r, err = finds[T](globalBb, ctx, q)
|
||||
return
|
||||
}
|
||||
|
||||
// FindFromDB 同 Finds 使用指定 db 查询
|
||||
//
|
||||
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
|
||||
func FindFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) {
|
||||
r, err = finds[T](db, ctx, q)
|
||||
return
|
||||
}
|
||||
|
||||
func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) {
|
||||
sq, args, err := BuildQuerySql[T](q)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = db.Select(ctx, &r, sq, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
|
||||
i := 1
|
||||
var rr []T
|
||||
var total int
|
||||
var offset int
|
||||
for {
|
||||
if 1 == i {
|
||||
rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...)
|
||||
} else {
|
||||
q.offset = offset
|
||||
q.limit = perLimit
|
||||
rr, err = finds[T](db, ctx, q)
|
||||
}
|
||||
offset += perLimit
|
||||
if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 {
|
||||
return
|
||||
}
|
||||
r = append(r, rr...)
|
||||
if len(r) >= total {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ChunkFind 分片查询并直接返回所有结果
|
||||
//
|
||||
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
|
||||
func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
|
||||
r, err = chunkFind[T](globalBb, ctx, perLimit, q)
|
||||
return
|
||||
}
|
||||
|
||||
// ChunkFindFromDB 同 ChunkFind
|
||||
//
|
||||
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
|
||||
func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
|
||||
r, err = chunkFind[T](db, ctx, perLimit, q)
|
||||
return
|
||||
}
|
||||
|
||||
// Chunk 分片查询并函数过虑返回新类型的切片
|
||||
//
|
||||
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
|
||||
func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
|
||||
r, err = chunk(globalBb, ctx, perLimit, fn, q)
|
||||
return
|
||||
}
|
||||
|
||||
// ChunkFromDB 同 Chunk
|
||||
//
|
||||
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
|
||||
func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
|
||||
r, err = chunk(db, ctx, perLimit, fn, q)
|
||||
return
|
||||
}
|
||||
|
||||
func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
|
||||
i := 1
|
||||
var rr []T
|
||||
var count int
|
||||
var total int
|
||||
var offset int
|
||||
for {
|
||||
if 1 == i {
|
||||
rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...)
|
||||
} else {
|
||||
q.offset = offset
|
||||
q.limit = perLimit
|
||||
rr, err = finds[T](db, ctx, q)
|
||||
}
|
||||
offset += perLimit
|
||||
if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 {
|
||||
return
|
||||
}
|
||||
for _, t := range rr {
|
||||
v, ok := fn(t)
|
||||
if ok {
|
||||
r = append(r, v)
|
||||
}
|
||||
}
|
||||
count += len(rr)
|
||||
if count >= total {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Pagination 同 SimplePagination
|
||||
//
|
||||
// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数
|
||||
func Pagination[T Model](ctx context.Context, q *QueryCondition) ([]T, int, error) {
|
||||
return SimplePagination[T](ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...)
|
||||
}
|
||||
|
||||
// PaginationFromDB 同 Pagination 方便多个db使用
|
||||
//
|
||||
// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数
|
||||
func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) ([]T, int, error) {
|
||||
return pagination[T](db, ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...)
|
||||
}
|
||||
|
||||
func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q *QueryCondition) ([]T, error) {
|
||||
return column[V, T](globalBb, ctx, fn, q)
|
||||
}
|
||||
func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) {
|
||||
return column[V, T](db, ctx, fn, q)
|
||||
}
|
||||
|
||||
func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) {
|
||||
res, err := finds[V](db, ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r = slice.FilterAndMap(res, fn)
|
||||
return
|
||||
}
|
||||
|
||||
func GetField[T Model](ctx context.Context, field string, q *QueryCondition) (r string, err error) {
|
||||
r, err = getField[T](globalBb, ctx, field, q)
|
||||
return
|
||||
}
|
||||
func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) {
|
||||
if q.fields == "" || q.fields == "*" {
|
||||
q.fields = field
|
||||
}
|
||||
res, err := getToStringMap[T](db, ctx, q)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
f := strings.Split(field, " ")
|
||||
r, ok := res[f[len(f)-1]]
|
||||
if !ok {
|
||||
err = errors.New("not exists")
|
||||
}
|
||||
return
|
||||
}
|
||||
func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) {
|
||||
return getField[T](db, ctx, field, q)
|
||||
}
|
||||
|
||||
func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) {
|
||||
rawSql, in, err := BuildQuerySql[T](q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = context.WithValue(ctx, "handle=>", "string")
|
||||
err = db.Get(ctx, &r, rawSql, in...)
|
||||
return
|
||||
}
|
||||
func GetToStringMap[T Model](ctx context.Context, q *QueryCondition) (r map[string]string, err error) {
|
||||
r, err = getToStringMap[T](globalBb, ctx, q)
|
||||
return
|
||||
}
|
||||
|
||||
func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) {
|
||||
rawSql, in, err := BuildQuerySql[T](q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = context.WithValue(ctx, "handle=>", "string")
|
||||
err = db.Select(ctx, &r, rawSql, in...)
|
||||
return
|
||||
}
|
||||
|
||||
func FindToStringMap[T Model](ctx context.Context, q *QueryCondition) (r []map[string]string, err error) {
|
||||
r, err = findToStringMap[T](globalBb, ctx, q)
|
||||
return
|
||||
}
|
||||
|
||||
func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) {
|
||||
r, err = findToStringMap[T](db, ctx, q)
|
||||
return
|
||||
}
|
||||
|
||||
func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) {
|
||||
r, err = getToStringMap[T](db, ctx, q)
|
||||
return
|
||||
}
|
||||
|
||||
func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error) {
|
||||
var rr T
|
||||
w := ""
|
||||
|
@ -254,44 +42,10 @@ func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error)
|
|||
if q.offset > 0 {
|
||||
l = fmt.Sprintf(" %s offset %d", l, q.offset)
|
||||
}
|
||||
r = fmt.Sprintf(tp, q.fields, rr.Table(), j, w, groupBy, h, q.order.parseOrderBy(), l)
|
||||
return
|
||||
}
|
||||
|
||||
func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) (err error) {
|
||||
s, args, err := BuildQuerySql[T](q)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, "handle=>", "scanner")
|
||||
var v T
|
||||
ctx = context.WithValue(ctx, "fn", func(v any) {
|
||||
fn(*(v.(*T)))
|
||||
})
|
||||
err = db.Select(ctx, &v, s, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) error {
|
||||
return findScanner[T](db, ctx, fn, q)
|
||||
}
|
||||
|
||||
func FindScanner[T Model](ctx context.Context, fn func(T), q *QueryCondition) error {
|
||||
return findScanner[T](globalBb, ctx, fn, q)
|
||||
}
|
||||
|
||||
func Gets[T Model](ctx context.Context, q *QueryCondition) (T, error) {
|
||||
return gets[T](globalBb, ctx, q)
|
||||
}
|
||||
func GetsFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (T, error) {
|
||||
return gets[T](db, ctx, q)
|
||||
}
|
||||
|
||||
func gets[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r T, err error) {
|
||||
s, args, err := BuildQuerySql[T](q)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = db.Get(ctx, &r, s, args...)
|
||||
table := rr.Table()
|
||||
if q.from != "" {
|
||||
table = q.from
|
||||
}
|
||||
r = fmt.Sprintf(tp, q.fields, table, j, w, groupBy, h, q.order.parseOrderBy(), l)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,571 +1 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/fthvgb1/wp-go/helper/number"
|
||||
"github.com/fthvgb1/wp-go/helper/slice"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFinds(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
q *QueryCondition
|
||||
}
|
||||
type testCase[T Model] struct {
|
||||
name string
|
||||
args args
|
||||
wantR []T
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase[post]{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
q: Conditions(
|
||||
Where(SqlBuilder{
|
||||
{"post_status", "publish"}, {"ID", "in", ""}},
|
||||
),
|
||||
Order(SqlBuilder{{"ID", "desc"}}),
|
||||
Offset(10),
|
||||
Limit(10),
|
||||
In([][]any{slice.ToAnySlice(number.Range(1, 1000, 1))}...),
|
||||
),
|
||||
},
|
||||
wantR: func() []post {
|
||||
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, 1000, 1), strconv.Itoa), ",")+") order by ID desc limit 10 offset 10 ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotR, err := Finds[post](tt.args.ctx, tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Findx() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("Findx() gotR = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkFind(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
perLimit int
|
||||
q *QueryCondition
|
||||
}
|
||||
type testCase[T Model] struct {
|
||||
name string
|
||||
args args
|
||||
wantR []T
|
||||
wantErr bool
|
||||
}
|
||||
n := 500
|
||||
tests := []testCase[post]{
|
||||
{
|
||||
name: "in,orderBy",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
q: Conditions(
|
||||
Where(SqlBuilder{{
|
||||
"post_status", "publish",
|
||||
}, {"ID", "in", ""}}),
|
||||
Order(SqlBuilder{{"ID", "desc"}}),
|
||||
In([][]any{slice.ToAnySlice(number.Range(1, n, 1))}...),
|
||||
),
|
||||
perLimit: 20,
|
||||
},
|
||||
wantR: func() []post {
|
||||
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, n, 1), strconv.Itoa), ",")+") order by ID desc")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}(),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotR, err := ChunkFind[post](tt.args.ctx, tt.args.perLimit, tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ChunkFind() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("ChunkFind() gotR = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunk(t *testing.T) {
|
||||
type args[T Model, R any] struct {
|
||||
ctx context.Context
|
||||
perLimit int
|
||||
fn func(rows T) (R, bool)
|
||||
q *QueryCondition
|
||||
}
|
||||
type testCase[T Model, R any] struct {
|
||||
name string
|
||||
args args[T, R]
|
||||
wantR []R
|
||||
wantErr bool
|
||||
}
|
||||
n := 500
|
||||
tests := []testCase[post, uint64]{
|
||||
{
|
||||
name: "t1",
|
||||
args: args[post, uint64]{
|
||||
ctx: ctx,
|
||||
perLimit: 20,
|
||||
fn: func(t post) (uint64, bool) {
|
||||
if t.Id > 300 {
|
||||
return t.Id, true
|
||||
}
|
||||
return 0, false
|
||||
},
|
||||
q: Conditions(
|
||||
Where(SqlBuilder{{
|
||||
"post_status", "publish",
|
||||
}, {"ID", "in", ""}}),
|
||||
Order(SqlBuilder{{"ID", "desc"}}),
|
||||
In([][]any{slice.ToAnySlice(number.Range(1, n, 1))}...),
|
||||
),
|
||||
},
|
||||
wantR: func() []uint64 {
|
||||
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, n, 1), strconv.Itoa), ",")+") order by ID desc")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return slice.FilterAndMap(r, func(t post) (uint64, bool) {
|
||||
if t.Id <= 300 {
|
||||
return 0, false
|
||||
}
|
||||
return t.Id, true
|
||||
})
|
||||
}(),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotR, err := Chunk[post](tt.args.ctx, tt.args.perLimit, tt.args.fn, tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Chunk() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("Chunk() gotR = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPagination(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
q *QueryCondition
|
||||
}
|
||||
type testCase[T Model] struct {
|
||||
name string
|
||||
args args
|
||||
want []T
|
||||
want1 int
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase[post]{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
q: Conditions(
|
||||
Where(SqlBuilder{
|
||||
{"ID", "in", ""},
|
||||
}),
|
||||
Page(1),
|
||||
Limit(5),
|
||||
In([][]any{slice.ToAnySlice(number.Range(431, 440, 1))}...),
|
||||
),
|
||||
},
|
||||
want: func() (r []post) {
|
||||
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?,?,?,?)", slice.ToAnySlice(number.Range(431, 435, 1))...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
panic(err)
|
||||
} else if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}(),
|
||||
want1: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1, err := Pagination[post](tt.args.ctx, tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Pagination() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Pagination() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Pagination() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestColumn(t *testing.T) {
|
||||
type args[V Model, T any] struct {
|
||||
ctx context.Context
|
||||
fn func(V) (T, bool)
|
||||
q *QueryCondition
|
||||
}
|
||||
type testCase[V Model, T any] struct {
|
||||
name string
|
||||
args args[V, T]
|
||||
wantR []T
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase[post, uint64]{
|
||||
{
|
||||
name: "t1",
|
||||
args: args[post, uint64]{
|
||||
ctx: ctx,
|
||||
fn: func(t post) (uint64, bool) {
|
||||
return t.Id, true
|
||||
},
|
||||
q: Conditions(
|
||||
Where(SqlBuilder{
|
||||
{"ID", "<", "200", "int"},
|
||||
}),
|
||||
),
|
||||
},
|
||||
wantR: []uint64{63, 64, 190, 193},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotR, err := Column[post](tt.args.ctx, tt.args.fn, tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Column() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("Column() gotR = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
OptionId uint64 `gorm:"column:option_id" db:"option_id" json:"option_id" form:"option_id"`
|
||||
OptionName string `gorm:"column:option_name" db:"option_name" json:"option_name" form:"option_name"`
|
||||
OptionValue string `gorm:"column:option_value" db:"option_value" json:"option_value" form:"option_value"`
|
||||
Autoload string `gorm:"column:autoload" db:"autoload" json:"autoload" form:"autoload"`
|
||||
}
|
||||
|
||||
func (w options) PrimaryKey() string {
|
||||
return "option_id"
|
||||
}
|
||||
|
||||
func (w options) Table() string {
|
||||
return "wp_options"
|
||||
}
|
||||
|
||||
func Test_getField(t *testing.T) {
|
||||
{
|
||||
name := "string"
|
||||
db := glob
|
||||
field := "option_value"
|
||||
q := Conditions(Where(SqlBuilder{{"option_name", "blogname"}}))
|
||||
wantR := "记录并见证自己的成长"
|
||||
wantErr := false
|
||||
t.Run(name, func(t *testing.T) {
|
||||
gotR, err := getField[options](db, ctx, field, q)
|
||||
if (err != nil) != wantErr {
|
||||
t.Errorf("getField() error = %v, wantErr %v", err, wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, wantR) {
|
||||
t.Errorf("getField() gotR = %v, want %v", gotR, wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
{
|
||||
name := "t2"
|
||||
db := glob
|
||||
field := "option_id"
|
||||
q := Conditions(Where(SqlBuilder{{"option_name", "blogname"}}))
|
||||
wantR := "3"
|
||||
wantErr := false
|
||||
t.Run(name, func(t *testing.T) {
|
||||
gotR, err := getField[options](db, ctx, field, q)
|
||||
if (err != nil) != wantErr {
|
||||
t.Errorf("getField() error = %v, wantErr %v", err, wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, wantR) {
|
||||
t.Errorf("getField() gotR = %v, want %v", gotR, wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
{
|
||||
name := "count(*)"
|
||||
db := glob
|
||||
field := "count(*)"
|
||||
q := Conditions()
|
||||
wantR := "386"
|
||||
wantErr := false
|
||||
t.Run(name, func(t *testing.T) {
|
||||
gotR, err := getField[options](db, ctx, field, q)
|
||||
if (err != nil) != wantErr {
|
||||
t.Errorf("getField() error = %v, wantErr %v", err, wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, wantR) {
|
||||
t.Errorf("getField() gotR = %v, want %v", gotR, wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getToStringMap(t *testing.T) {
|
||||
type args struct {
|
||||
db dbQuery
|
||||
ctx context.Context
|
||||
q *QueryCondition
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantR map[string]string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
db: glob,
|
||||
ctx: ctx,
|
||||
q: Conditions(Where(SqlBuilder{{"option_name", "users_can_register"}})),
|
||||
},
|
||||
wantR: map[string]string{
|
||||
"option_id": "5",
|
||||
"option_value": "0",
|
||||
"option_name": "users_can_register",
|
||||
"autoload": "yes",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "t2",
|
||||
args: args{
|
||||
db: glob,
|
||||
ctx: ctx,
|
||||
q: Conditions(
|
||||
Where(SqlBuilder{{"option_name", "users_can_register"}}),
|
||||
Fields("option_id id"),
|
||||
),
|
||||
},
|
||||
wantR: map[string]string{
|
||||
"id": "5",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotR, err := getToStringMap[options](tt.args.db, tt.args.ctx, tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("getToStringMap() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("getToStringMap() gotR = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_findToStringMap(t *testing.T) {
|
||||
type args struct {
|
||||
db dbQuery
|
||||
ctx context.Context
|
||||
q *QueryCondition
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantR []map[string]string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
db: glob,
|
||||
ctx: ctx,
|
||||
q: Conditions(Where(SqlBuilder{{"option_id", "5"}})),
|
||||
},
|
||||
wantR: []map[string]string{{
|
||||
"option_id": "5",
|
||||
"option_value": "0",
|
||||
"option_name": "users_can_register",
|
||||
"autoload": "yes",
|
||||
}},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "t2",
|
||||
args: args{
|
||||
db: glob,
|
||||
ctx: ctx,
|
||||
q: Conditions(
|
||||
Where(SqlBuilder{{"option_id", "5"}}),
|
||||
Fields("option_value,option_name"),
|
||||
),
|
||||
},
|
||||
wantR: []map[string]string{{
|
||||
"option_value": "0",
|
||||
"option_name": "users_can_register",
|
||||
}},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "t3",
|
||||
args: args{
|
||||
db: glob,
|
||||
ctx: ctx,
|
||||
q: Conditions(
|
||||
Where(SqlBuilder{{"option_id", "5"}}),
|
||||
Fields("option_value v,option_name k"),
|
||||
),
|
||||
},
|
||||
wantR: []map[string]string{{
|
||||
"v": "0",
|
||||
"k": "users_can_register",
|
||||
}},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotR, err := findToStringMap[options](tt.args.db, tt.args.ctx, tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("findToStringMap() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("findToStringMap() gotR = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_findScanner(t *testing.T) {
|
||||
type args[T Model] struct {
|
||||
db dbQuery
|
||||
ctx context.Context
|
||||
fn func(T)
|
||||
q *QueryCondition
|
||||
}
|
||||
type testCase[T Model] struct {
|
||||
name string
|
||||
args args[T]
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase[options]{
|
||||
{
|
||||
name: "t1",
|
||||
args: args[options]{glob, ctx, func(t options) {
|
||||
fmt.Println(t)
|
||||
}, Conditions(Where(SqlBuilder{{"option_id", "<", "10", "int"}}))},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := findScanner[options](tt.args.db, tt.args.ctx, tt.args.fn, tt.args.q); (err != nil) != tt.wantErr {
|
||||
t.Errorf("findScanner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkScannerXX(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := findScanner[options](glob, ctx, func(t options) {
|
||||
_ = t
|
||||
//fmt.Println(t)
|
||||
}, Conditions(Where(SqlBuilder{{"option_id", "<", "100", "int"}})))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFindsXX(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, err := finds[options](glob, ctx, Conditions(Where(SqlBuilder{{"option_id", "<", "100", "int"}})))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, o := range r {
|
||||
_ = o
|
||||
//fmt.Println(o)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_gets(t *testing.T) {
|
||||
type args struct {
|
||||
db dbQuery
|
||||
ctx context.Context
|
||||
q *QueryCondition
|
||||
}
|
||||
type testCase[T Model] struct {
|
||||
name string
|
||||
args args
|
||||
wantR T
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase[options]{
|
||||
{
|
||||
name: "t1",
|
||||
args: args{
|
||||
db: glob,
|
||||
ctx: ctx,
|
||||
q: Conditions(Where(SqlBuilder{{"option_name", "blogname"}})),
|
||||
},
|
||||
wantR: options{3, "blogname", "记录并见证自己的成长", "yes"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotR, err := gets[options](tt.args.db, tt.args.ctx, tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("gets() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("gets() gotR = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,14 +9,14 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
type SqlxQuery struct {
|
||||
type SqlxQuery[T any] struct {
|
||||
sqlx *sqlx.DB
|
||||
UniversalDb
|
||||
UniversalDb[T]
|
||||
}
|
||||
|
||||
func NewSqlxQuery(sqlx *sqlx.DB, u UniversalDb) *SqlxQuery {
|
||||
func NewSqlxQuery[T any](sqlx *sqlx.DB, u UniversalDb[T]) *SqlxQuery[T] {
|
||||
|
||||
s := &SqlxQuery{sqlx: sqlx, UniversalDb: u}
|
||||
s := &SqlxQuery[T]{sqlx: sqlx, UniversalDb: u}
|
||||
if u.selects == nil {
|
||||
s.UniversalDb.selects = s.Selects
|
||||
}
|
||||
|
@ -26,52 +26,56 @@ func NewSqlxQuery(sqlx *sqlx.DB, u UniversalDb) *SqlxQuery {
|
|||
return s
|
||||
}
|
||||
|
||||
func SetSelect(db *SqlxQuery, fn func(context.Context, any, string, ...any) error) {
|
||||
func SetSelect[T any](db *SqlxQuery[T], fn QuerySelect[T]) {
|
||||
db.selects = fn
|
||||
}
|
||||
func SetGet(db *SqlxQuery, fn func(context.Context, any, string, ...any) error) {
|
||||
func SetGet[T any](db *SqlxQuery[T], fn QueryGet[T]) {
|
||||
db.gets = fn
|
||||
}
|
||||
|
||||
func (r *SqlxQuery) Selects(ctx context.Context, dest any, sql string, params ...any) error {
|
||||
func (s *SqlxQuery[T]) Selects(ctx context.Context, sql string, params ...any) (r []T, err error) {
|
||||
v := ctx.Value("handle=>")
|
||||
if v != nil {
|
||||
vv, ok := v.(string)
|
||||
if ok && vv != "" {
|
||||
switch vv {
|
||||
case "string":
|
||||
return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...)
|
||||
//return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...)
|
||||
case "scanner":
|
||||
fn := ctx.Value("fn")
|
||||
return Scanner[any](r.sqlx, dest, sql, params...)(fn.(func(any)))
|
||||
return nil, Scanner[T](s.sqlx, sql, params...)(fn.(func(T)))
|
||||
}
|
||||
}
|
||||
}
|
||||
return r.sqlx.Select(dest, sql, params...)
|
||||
//var a T
|
||||
err = s.sqlx.Select(&r, sql, params...)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *SqlxQuery) Gets(ctx context.Context, dest any, sql string, params ...any) error {
|
||||
func (s *SqlxQuery[T]) Gets(ctx context.Context, sql string, params ...any) (r T, err error) {
|
||||
v := ctx.Value("handle=>")
|
||||
if v != nil {
|
||||
vv, ok := v.(string)
|
||||
if ok && vv != "" {
|
||||
switch vv {
|
||||
case "string":
|
||||
return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...)
|
||||
//return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...)
|
||||
}
|
||||
}
|
||||
}
|
||||
return r.sqlx.Get(dest, sql, params...)
|
||||
err = s.sqlx.Get(&r, sql, params...)
|
||||
return
|
||||
}
|
||||
|
||||
func Scanner[T any](db *sqlx.DB, v T, s string, params ...any) func(func(T)) error {
|
||||
func Scanner[T any](db *sqlx.DB, s string, params ...any) func(func(T)) error {
|
||||
var v T
|
||||
return func(fn func(T)) error {
|
||||
rows, err := db.Queryx(s, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for rows.Next() {
|
||||
err = rows.StructScan(v)
|
||||
err = rows.StructScan(&v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,19 +2,15 @@ package model
|
|||
|
||||
import "context"
|
||||
|
||||
type UniversalDb struct {
|
||||
selects QueryFn
|
||||
gets QueryFn
|
||||
type UniversalDb[T any] struct {
|
||||
selects QuerySelect[T]
|
||||
gets QueryGet[T]
|
||||
}
|
||||
|
||||
func NewUniversalDb(selects QueryFn, gets QueryFn) UniversalDb {
|
||||
return UniversalDb{selects: selects, gets: gets}
|
||||
func (u *UniversalDb[T]) Select(ctx context.Context, s string, a ...any) ([]T, error) {
|
||||
return u.selects(ctx, s, a...)
|
||||
}
|
||||
|
||||
func (u UniversalDb) Select(ctx context.Context, a any, s string, args ...any) error {
|
||||
return u.selects(ctx, a, s, args...)
|
||||
}
|
||||
|
||||
func (u UniversalDb) Get(ctx context.Context, a any, s string, args ...any) error {
|
||||
return u.gets(ctx, a, s, args...)
|
||||
func (u *UniversalDb[T]) Get(ctx context.Context, s string, a ...any) (T, error) {
|
||||
return u.gets(ctx, s, a...)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user