优化数据库查询相关

This commit is contained in:
xing 2023-02-25 23:10:42 +08:00
parent 00e42c2d56
commit e0786f7f8b
8 changed files with 202 additions and 233 deletions

View File

@ -91,18 +91,9 @@ func GetPostsByIds(a ...any) (m map[uint64]models.Posts, err error) {
func SearchPostIds(args ...any) (ids PostIds, err error) { func SearchPostIds(args ...any) (ids PostIds, err error) {
ctx := args[0].(context.Context) ctx := args[0].(context.Context)
where := args[1].(model.SqlBuilder) q := args[1].(model.QueryCondition)
page := args[2].(int) q.Fields = "ID"
limit := args[3].(int) res, total, err := model.Pagination[models.Posts](ctx, q)
order := args[4].(model.SqlBuilder)
join := args[5].(model.SqlBuilder)
postType := args[6].([]any)
postStatus := args[7].([]any)
res, total, err := model.SimplePagination[models.Posts](
ctx, where, "ID",
"", page, limit, order,
join, nil, postType, postStatus,
)
for _, posts := range res { for _, posts := range res {
ids.Ids = append(ids.Ids, posts.Id) ids.Ids = append(ids.Ids, posts.Id)
} }

View File

@ -56,16 +56,22 @@ func (i *IndexHandle) ParseIndex(parm *IndexParams) (err error) {
func (i *IndexHandle) GetIndexData() (posts []models.Posts, totalRaw int, err error) { func (i *IndexHandle) GetIndexData() (posts []models.Posts, totalRaw int, err error) {
q := model.QueryCondition{
Where: i.Param.Where,
Page: i.Param.Page,
Limit: i.Param.PageSize,
Order: model.SqlBuilder{{i.Param.OrderBy, i.Param.Order}},
Join: i.Param.Join,
In: [][]any{i.Param.PostType, i.Param.PostStatus},
}
switch i.Scene { switch i.Scene {
case constraints.Home, constraints.Category, constraints.Tag, constraints.Author: case constraints.Home, constraints.Category, constraints.Tag, constraints.Author:
posts, totalRaw, err = cache.PostLists(i.C, i.Param.CacheKey, i.C, i.Param.Where, i.Param.Page, i.Param.PageSize, posts, totalRaw, err = cache.PostLists(i.C, i.Param.CacheKey, i.C, q)
model.SqlBuilder{{i.Param.OrderBy, i.Param.Order}}, i.Param.Join, i.Param.PostType, i.Param.PostStatus)
case constraints.Search: case constraints.Search:
posts, totalRaw, err = cache.SearchPost(i.C, i.Param.CacheKey, i.C, i.Param.Where, i.Param.Page, i.Param.PageSize, posts, totalRaw, err = cache.SearchPost(i.C, i.Param.CacheKey, i.C, q)
model.SqlBuilder{{i.Param.OrderBy, i.Param.Order}}, i.Param.Join, i.Param.PostType, i.Param.PostStatus)
case constraints.Archive: case constraints.Archive:

View File

@ -1,26 +1,26 @@
package model package model
type QueryCondition struct { type QueryCondition struct {
where ParseWhere Where ParseWhere
from string From string
fields string Fields string
group string Group string
order SqlBuilder Order SqlBuilder
join SqlBuilder Join SqlBuilder
having SqlBuilder Having SqlBuilder
page int Page int
limit int Limit int
offset int Offset int
in [][]any In [][]any
} }
func Conditions(fns ...Condition) *QueryCondition { func Conditions(fns ...Condition) QueryCondition {
r := &QueryCondition{} r := QueryCondition{}
for _, fn := range fns { for _, fn := range fns {
fn(r) fn(&r)
} }
if r.fields == "" { if r.Fields == "" {
r.fields = "*" r.Fields = "*"
} }
return r return r
} }
@ -29,65 +29,65 @@ type Condition func(c *QueryCondition)
func Where(where ParseWhere) Condition { func Where(where ParseWhere) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.where = where c.Where = where
} }
} }
func Fields(fields string) Condition { func Fields(fields string) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.fields = fields c.Fields = fields
} }
} }
func From(from string) Condition { func From(from string) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.from = from c.From = from
} }
} }
func Group(group string) Condition { func Group(group string) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.group = group c.Group = group
} }
} }
func Order(order SqlBuilder) Condition { func Order(order SqlBuilder) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.order = order c.Order = order
} }
} }
func Join(join SqlBuilder) Condition { func Join(join SqlBuilder) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.join = join c.Join = join
} }
} }
func Having(having SqlBuilder) Condition { func Having(having SqlBuilder) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.having = having c.Having = having
} }
} }
func Page(page int) Condition { func Page(page int) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.page = page c.Page = page
} }
} }
func Limit(limit int) Condition { func Limit(limit int) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.limit = limit c.Limit = limit
} }
} }
func Offset(offset int) Condition { func Offset(offset int) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.offset = offset c.Offset = offset
} }
} }
func In(in ...[]any) Condition { func In(in ...[]any) Condition {
return func(c *QueryCondition) { return func(c *QueryCondition) {
c.in = append(c.in, in...) c.In = append(c.In, in...)
} }
} }

View File

@ -56,9 +56,11 @@ func (w SqlBuilder) parseType(ss []string, args *[]any) error {
// ParseWhere 解析为where条件支持3种风格,具体用法参照query_test中的 Find 的测试方法 // ParseWhere 解析为where条件支持3种风格,具体用法参照query_test中的 Find 的测试方法
// //
// 1. 2个为一组 {{"field1","value1"},{"field2","value2"}} => where field1='value1' and field2='value2' // 1. 1个为一组 {{"field operator value"}} 为纯字符串条件,不对参数做处理
// //
// 2. 3个或4个为一组 {{"field","operator","value"[,"int|float"]}} => where field operator 'string'|int|float // 2. 2个为一组 {{"field1","value1"},{"field2","value2"}} => where field1='value1' and field2='value2'
//
// 3. 3个或4个为一组 {{"field","operator","value"[,"int|float"]}} => where field operator 'string'|int|float
// //
// {{"a",">","1","int"}} => where 'a'> 1 // {{"a",">","1","int"}} => where 'a'> 1
// //
@ -66,7 +68,7 @@ func (w SqlBuilder) parseType(ss []string, args *[]any) error {
// //
// 另外如果是操作符为in的话为 {{"field","in",""}} => where field in (?,..) in的条件传给 in参数 // 另外如果是操作符为in的话为 {{"field","in",""}} => where field in (?,..) in的条件传给 in参数
// //
// 3. 5的倍数为一组{{"and|or","field","operator","value","int|float"}}会忽然掉第一组的and|or // 4. 5的倍数为一组{{"and|or","field","operator","value","int|float"}}会忽然掉第一组的and|or
// //
// {{"and","field","=","value1","","and","field","=","value2",""}} => where (field = 'value1' and field = 'value2') // {{"and","field","=","value1","","and","field","=","value2",""}} => where (field = 'value1' and field = 'value2')
// //

View File

@ -3,81 +3,54 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/fthvgb1/wp-go/helper/number"
str "github.com/fthvgb1/wp-go/helper/strings"
"golang.org/x/exp/constraints" "golang.org/x/exp/constraints"
"math/rand" "math/rand"
"strings" "strings"
) )
func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, total int, err error) {
var rr T qx := QueryCondition{
var w string Where: q.Where,
var args []any Having: q.Having,
if where != nil { Join: q.Join,
w, args, err = where.ParseWhere(&in) In: q.In,
if err != nil { Group: q.Group,
return r, total, err From: q.From,
}
if q.Group != "" {
qx.Fields = q.Fields
sq, in, er := BuildQuerySql[T](qx)
qx.In = [][]any{in}
if er != nil {
err = er
return
}
qx.From = str.Join("( ", sq, " ) ", "table", number.ToString(rand.Int()))
qx = QueryCondition{
From: qx.From,
In: qx.In,
} }
} }
h := "" n, err := GetField[T](ctx, "count(*)", qx)
if having != nil { total = str.ToInt[int](n)
hh, arg, err := having.ParseWhere(&in) if err != nil || total < 1 {
if err != nil {
return r, total, err
}
args = append(args, arg...)
h = strings.Replace(hh, " where", " having", 1)
}
n := struct {
N int `db:"n" json:"n"`
}{}
groupBy := ""
if group != "" {
g := strings.Builder{}
g.WriteString(" group by ")
g.WriteString(group)
groupBy = g.String()
}
if having != nil {
tm := map[string]struct{}{}
for _, s := range strings.Split(group, ",") {
tm[s] = struct{}{}
}
for _, ss := range having {
if _, ok := tm[ss[0]]; !ok {
group = fmt.Sprintf("%s,%s", group, ss[0])
}
}
group = strings.Trim(group, ",")
}
j := join.parseJoin()
if group == "" {
tpx := "select count(*) n from %s %s %s limit 1"
sq := fmt.Sprintf(tpx, rr.Table(), j, w)
err = db.Get(ctx, &n, sq, args...)
} else {
tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s"
sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int()))
err = db.Get(ctx, &n, sq, args...)
}
if err != nil {
return return
} }
if n.N == 0 {
return
}
total = n.N
offset := 0 offset := 0
if page > 1 { if q.Page > 1 {
offset = (page - 1) * pageSize offset = (q.Page - 1) * q.Limit
} }
if offset >= total { if offset >= total {
return return
} }
tp := "select %s from %s %s %s %s %s %s limit %d,%d" q.Offset = offset
sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) sq, args, err := BuildQuerySql[T](q)
if err != nil {
return
}
err = db.Select(ctx, &r, sq, args...) err = db.Select(ctx, &r, sq, args...)
if err != nil { if err != nil {
return return
@ -85,11 +58,6 @@ func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fiel
return return
} }
func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) {
r, total, err = pagination[T](globalBb, ctx, where, fields, group, page, pageSize, order, join, having, in...)
return
}
func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) { func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) {
var r T var r T
sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey())
@ -101,12 +69,12 @@ func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T,
} }
func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (r T, err error) { func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (r T, err error) {
s, args, err := BuildQuerySql[T](&QueryCondition{ s, args, err := BuildQuerySql[T](QueryCondition{
where: where, Where: where,
fields: fields, Fields: fields,
order: order, Order: order,
in: in, In: in,
limit: 1, Limit: 1,
}) })
if err != nil { if err != nil {
return return
@ -136,10 +104,10 @@ func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in .
} }
func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) { func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) {
s, args, err := BuildQuerySql[T](&QueryCondition{ s, args, err := BuildQuerySql[T](QueryCondition{
where: where, Where: where,
fields: fields, Fields: fields,
in: in, In: in,
}) })
if err != nil { if err != nil {
return return
@ -161,16 +129,16 @@ func Select[T Model](ctx context.Context, sql string, params ...any) ([]T, error
func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) { func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) {
q := QueryCondition{ q := QueryCondition{
where: where, Where: where,
fields: fields, Fields: fields,
group: group, Group: group,
order: order, Order: order,
join: join, Join: join,
having: having, Having: having,
limit: limit, Limit: limit,
in: in, In: in,
} }
s, args, err := BuildQuerySql[T](&q) s, args, err := BuildQuerySql[T](q)
if err != nil { if err != nil {
return return
} }

View File

@ -3,8 +3,7 @@ package model
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/fthvgb1/wp-go/helper/number" "fmt"
"github.com/fthvgb1/wp-go/helper/slice"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"reflect" "reflect"
@ -108,7 +107,15 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
glob = NewSqlxQuery(db, NewUniversalDb(nil, nil)) glob = NewSqlxQuery(db, NewUniversalDb(func(ctx2 context.Context, a any, s string, a2 ...any) error {
x := FormatSql(s, a2...)
fmt.Println(x)
return glob.Selects(ctx2, a, s, a2...)
}, func(ctx2 context.Context, a any, s string, a2 ...any) error {
x := FormatSql(s, a2...)
fmt.Println(x)
return glob.Gets(ctx2, a, s, a2...)
}))
ddb = db ddb = db
InitDB(glob) InitDB(glob)
} }
@ -456,65 +463,55 @@ func TestSimpleFind(t *testing.T) {
} }
} }
func TestSimplePagination(t *testing.T) { func Test_pagination(t *testing.T) {
type args struct { type args struct {
where ParseWhere db dbQuery
fields string ctx context.Context
group string q QueryCondition
page int
pageSize int
order SqlBuilder
join SqlBuilder
having SqlBuilder
in [][]any
} }
tests := []struct { type testCase[T Model] struct {
name string name string
args args args args
wantR []post wantR []T
wantTotal int wantTotal int
wantErr bool wantErr bool
}{ }
tests := []testCase[post]{
{ {
name: "t1", name: "t1",
args: args{ args: args{
where: SqlBuilder{ db: glob,
{"ID", "in", ""}, ctx: ctx,
q: QueryCondition{
Fields: "post_type,count(*) ID",
Group: "post_type",
Having: SqlBuilder{{"ID", ">", "1", "int"}},
}, },
fields: "*",
group: "",
page: 1,
pageSize: 5,
order: nil,
join: nil,
having: nil,
in: [][]any{slice.ToAnySlice(number.Range(431, 440, 1))},
}, },
wantR: func() (r []post) { wantR: func() (r []post) {
r, err := Select[post](ctx, "select * from "+post{}.Table()+" where ID in (?,?,?,?,?)", slice.ToAnySlice(number.Range(431, 435, 1))...)
if err != nil && err != sql.ErrNoRows { err := glob.Selects(ctx, &r, "select post_type,count(*) ID from wp_posts group by post_type having `ID`> 1")
if err != nil {
panic(err) panic(err)
} else if err == sql.ErrNoRows {
err = nil
} }
return return r
}(), }(),
wantTotal: 10, wantTotal: 7,
wantErr: false, wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotR, gotTotal, err := SimplePagination[post](ctx, tt.args.where, tt.args.fields, tt.args.group, tt.args.page, tt.args.pageSize, tt.args.order, tt.args.join, tt.args.having, tt.args.in...) gotR, gotTotal, err := pagination[post](tt.args.db, tt.args.ctx, tt.args.q)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("SimplePagination() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("pagination() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(gotR, tt.wantR) { if !reflect.DeepEqual(gotR, tt.wantR) {
t.Errorf("SimplePagination() gotR = %v, want %v", gotR, tt.wantR) t.Errorf("pagination() gotR = %v, want %v", gotR, tt.wantR)
} }
if gotTotal != tt.wantTotal { if gotTotal != tt.wantTotal {
t.Errorf("SimplePagination() gotTotal = %v, want %v", gotTotal, tt.wantTotal) t.Errorf("pagination() gotTotal = %v, want %v", gotTotal, tt.wantTotal)
} }
}) })
} }

View File

@ -12,7 +12,7 @@ import (
// Finds 比 Find 多一个offset // Finds 比 Find 多一个offset
// //
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 // Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { func Finds[T Model](ctx context.Context, q QueryCondition) (r []T, err error) {
r, err = finds[T](globalBb, ctx, q) r, err = finds[T](globalBb, ctx, q)
return return
} }
@ -20,12 +20,12 @@ func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) {
// FindFromDB 同 Finds 使用指定 db 查询 // FindFromDB 同 Finds 使用指定 db 查询
// //
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 // Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
func FindFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { func FindFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) {
r, err = finds[T](db, ctx, q) r, err = finds[T](db, ctx, q)
return return
} }
func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { func finds[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) {
sq, args, err := BuildQuerySql[T](q) sq, args, err := BuildQuerySql[T](q)
if err != nil { if err != nil {
return return
@ -34,17 +34,17 @@ func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T,
return return
} }
func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) {
i := 1 i := 1
var rr []T var rr []T
var total int var total int
var offset int var offset int
for { for {
if 1 == i { if 1 == i {
rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) rr, total, err = pagination[T](db, ctx, q)
} else { } else {
q.offset = offset q.Offset = offset
q.limit = perLimit q.Limit = perLimit
rr, err = finds[T](db, ctx, q) rr, err = finds[T](db, ctx, q)
} }
offset += perLimit offset += perLimit
@ -63,7 +63,7 @@ func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryC
// ChunkFind 分片查询并直接返回所有结果 // ChunkFind 分片查询并直接返回所有结果
// //
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { func ChunkFind[T Model](ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) {
r, err = chunkFind[T](globalBb, ctx, perLimit, q) r, err = chunkFind[T](globalBb, ctx, perLimit, q)
return return
} }
@ -71,7 +71,7 @@ func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r
// ChunkFindFromDB 同 ChunkFind // ChunkFindFromDB 同 ChunkFind
// //
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) {
r, err = chunkFind[T](db, ctx, perLimit, q) r, err = chunkFind[T](db, ctx, perLimit, q)
return return
} }
@ -79,7 +79,7 @@ func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q *
// Chunk 分片查询并函数过虑返回新类型的切片 // Chunk 分片查询并函数过虑返回新类型的切片
// //
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) {
r, err = chunk(globalBb, ctx, perLimit, fn, q) r, err = chunk(globalBb, ctx, perLimit, fn, q)
return return
} }
@ -87,12 +87,12 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R
// ChunkFromDB 同 Chunk // ChunkFromDB 同 Chunk
// //
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) {
r, err = chunk(db, ctx, perLimit, fn, q) r, err = chunk(db, ctx, perLimit, fn, q)
return return
} }
func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) {
i := 1 i := 1
var rr []T var rr []T
var count int var count int
@ -100,10 +100,10 @@ func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn fun
var offset int var offset int
for { for {
if 1 == i { if 1 == i {
rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) rr, total, err = pagination[T](db, ctx, q)
} else { } else {
q.offset = offset q.Offset = offset
q.limit = perLimit q.Limit = perLimit
rr, err = finds[T](db, ctx, q) rr, err = finds[T](db, ctx, q)
} }
offset += perLimit offset += perLimit
@ -125,28 +125,28 @@ func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn fun
return return
} }
// Pagination 同 SimplePagination // Pagination 同
// //
// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数 // Condition 中可使用 Where Fields From Group Having Join Order Page Limit In 函数
func Pagination[T Model](ctx context.Context, q *QueryCondition) ([]T, int, error) { func Pagination[T Model](ctx context.Context, q QueryCondition) ([]T, int, error) {
return SimplePagination[T](ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) return pagination[T](globalBb, ctx, q)
} }
// PaginationFromDB 同 Pagination 方便多个db使用 // PaginationFromDB 同 Pagination 方便多个db使用
// //
// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数 // Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数
func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) ([]T, int, error) { func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) ([]T, int, error) {
return pagination[T](db, ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) return pagination[T](db, ctx, q)
} }
func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q *QueryCondition) ([]T, error) { func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q QueryCondition) ([]T, error) {
return column[V, T](globalBb, ctx, fn, q) return column[V, T](globalBb, ctx, fn, q)
} }
func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q QueryCondition) (r []T, err error) {
return column[V, T](db, ctx, fn, q) return column[V, T](db, ctx, fn, q)
} }
func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q QueryCondition) (r []T, err error) {
res, err := finds[V](db, ctx, q) res, err := finds[V](db, ctx, q)
if err != nil { if err != nil {
return nil, err return nil, err
@ -155,13 +155,13 @@ func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool
return return
} }
func GetField[T Model](ctx context.Context, field string, q *QueryCondition) (r string, err error) { func GetField[T Model](ctx context.Context, field string, q QueryCondition) (r string, err error) {
r, err = getField[T](globalBb, ctx, field, q) r, err = getField[T](globalBb, ctx, field, q)
return return
} }
func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { func getField[T Model](db dbQuery, ctx context.Context, field string, q QueryCondition) (r string, err error) {
if q.fields == "" || q.fields == "*" { if q.Fields == "" || q.Fields == "*" {
q.fields = field q.Fields = field
} }
res, err := getToStringMap[T](db, ctx, q) res, err := getToStringMap[T](db, ctx, q)
if err != nil { if err != nil {
@ -174,11 +174,11 @@ func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCo
} }
return return
} }
func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q QueryCondition) (r string, err error) {
return getField[T](db, ctx, field, q) return getField[T](db, ctx, field, q)
} }
func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { func getToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) {
rawSql, in, err := BuildQuerySql[T](q) rawSql, in, err := BuildQuerySql[T](q)
if err != nil { if err != nil {
return nil, err return nil, err
@ -187,12 +187,12 @@ func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition)
err = db.Get(ctx, &r, rawSql, in...) err = db.Get(ctx, &r, rawSql, in...)
return return
} }
func GetToStringMap[T Model](ctx context.Context, q *QueryCondition) (r map[string]string, err error) { func GetToStringMap[T Model](ctx context.Context, q QueryCondition) (r map[string]string, err error) {
r, err = getToStringMap[T](globalBb, ctx, q) r, err = getToStringMap[T](globalBb, ctx, q)
return return
} }
func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { func findToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) {
rawSql, in, err := BuildQuerySql[T](q) rawSql, in, err := BuildQuerySql[T](q)
if err != nil { if err != nil {
return nil, err return nil, err
@ -202,33 +202,33 @@ func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition
return return
} }
func FindToStringMap[T Model](ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { func FindToStringMap[T Model](ctx context.Context, q QueryCondition) (r []map[string]string, err error) {
r, err = findToStringMap[T](globalBb, ctx, q) r, err = findToStringMap[T](globalBb, ctx, q)
return return
} }
func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) {
r, err = findToStringMap[T](db, ctx, q) r, err = findToStringMap[T](db, ctx, q)
return return
} }
func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) {
r, err = getToStringMap[T](db, ctx, q) r, err = getToStringMap[T](db, ctx, q)
return return
} }
func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error) { func BuildQuerySql[T Model](q QueryCondition) (r string, args []any, err error) {
var rr T var rr T
w := "" w := ""
if q.where != nil { if q.Where != nil {
w, args, err = q.where.ParseWhere(&q.in) w, args, err = q.Where.ParseWhere(&q.In)
if err != nil { if err != nil {
return return
} }
} }
h := "" h := ""
if q.having != nil { if q.Having != nil {
hh, arg, er := q.having.ParseWhere(&q.in) hh, arg, er := q.Having.ParseWhere(&q.In)
if er != nil { if er != nil {
err = er err = er
return return
@ -236,32 +236,37 @@ func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error)
args = append(args, arg...) args = append(args, arg...)
h = strings.Replace(hh, " where", " having", 1) h = strings.Replace(hh, " where", " having", 1)
} }
if len(args) == 0 && len(q.In) > 0 {
for _, antes := range q.In {
args = append(args, antes...)
}
}
j := q.join.parseJoin() j := q.Join.parseJoin()
groupBy := "" groupBy := ""
if q.group != "" { if q.Group != "" {
g := strings.Builder{} g := strings.Builder{}
g.WriteString(" group by ") g.WriteString(" group by ")
g.WriteString(q.group) g.WriteString(q.Group)
groupBy = g.String() groupBy = g.String()
} }
tp := "select %s from %s %s %s %s %s %s %s" tp := "select %s from %s %s %s %s %s %s %s"
l := "" l := ""
table := rr.Table() table := rr.Table()
if q.from != "" { if q.From != "" {
table = q.from table = q.From
} }
if q.limit > 0 { if q.Limit > 0 {
l = fmt.Sprintf(" limit %d", q.limit) l = fmt.Sprintf(" limit %d", q.Limit)
} }
if q.offset > 0 { if q.Offset > 0 {
l = fmt.Sprintf(" %s offset %d", l, q.offset) l = fmt.Sprintf(" %s offset %d", l, q.Offset)
} }
r = fmt.Sprintf(tp, q.fields, table, j, w, groupBy, h, q.order.parseOrderBy(), l) r = fmt.Sprintf(tp, q.Fields, table, j, w, groupBy, h, q.Order.parseOrderBy(), l)
return return
} }
func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) (err error) { func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) (err error) {
s, args, err := BuildQuerySql[T](q) s, args, err := BuildQuerySql[T](q)
if err != nil { if err != nil {
return return
@ -275,22 +280,22 @@ func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryC
return return
} }
func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) error { func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) error {
return findScanner[T](db, ctx, fn, q) return findScanner[T](db, ctx, fn, q)
} }
func FindScanner[T Model](ctx context.Context, fn func(T), q *QueryCondition) error { func FindScanner[T Model](ctx context.Context, fn func(T), q QueryCondition) error {
return findScanner[T](globalBb, ctx, fn, q) return findScanner[T](globalBb, ctx, fn, q)
} }
func Gets[T Model](ctx context.Context, q *QueryCondition) (T, error) { func Gets[T Model](ctx context.Context, q QueryCondition) (T, error) {
return gets[T](globalBb, ctx, q) return gets[T](globalBb, ctx, q)
} }
func GetsFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (T, error) { func GetsFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (T, error) {
return gets[T](db, ctx, q) return gets[T](db, ctx, q)
} }
func gets[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r T, err error) { func gets[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r T, err error) {
s, args, err := BuildQuerySql[T](q) s, args, err := BuildQuerySql[T](q)
if err != nil { if err != nil {
return return

View File

@ -15,7 +15,7 @@ import (
func TestFinds(t *testing.T) { func TestFinds(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
q *QueryCondition q QueryCondition
} }
type testCase[T Model] struct { type testCase[T Model] struct {
name string name string
@ -66,7 +66,7 @@ func TestChunkFind(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
perLimit int perLimit int
q *QueryCondition q QueryCondition
} }
type testCase[T Model] struct { type testCase[T Model] struct {
name string name string
@ -118,7 +118,7 @@ func TestChunk(t *testing.T) {
ctx context.Context ctx context.Context
perLimit int perLimit int
fn func(rows T) (R, bool) fn func(rows T) (R, bool)
q *QueryCondition q QueryCondition
} }
type testCase[T Model, R any] struct { type testCase[T Model, R any] struct {
name string name string
@ -179,7 +179,7 @@ func TestChunk(t *testing.T) {
func TestPagination(t *testing.T) { func TestPagination(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
q *QueryCondition q QueryCondition
} }
type testCase[T Model] struct { type testCase[T Model] struct {
name string name string
@ -236,7 +236,7 @@ func TestColumn(t *testing.T) {
type args[V Model, T any] struct { type args[V Model, T any] struct {
ctx context.Context ctx context.Context
fn func(V) (T, bool) fn func(V) (T, bool)
q *QueryCondition q QueryCondition
} }
type testCase[V Model, T any] struct { type testCase[V Model, T any] struct {
name string name string
@ -333,7 +333,7 @@ func Test_getField(t *testing.T) {
db := glob db := glob
field := "count(*)" field := "count(*)"
q := Conditions() q := Conditions()
wantR := "386" wantR := "387"
wantErr := false wantErr := false
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
gotR, err := getField[options](db, ctx, field, q) gotR, err := getField[options](db, ctx, field, q)
@ -352,7 +352,7 @@ func Test_getToStringMap(t *testing.T) {
type args struct { type args struct {
db dbQuery db dbQuery
ctx context.Context ctx context.Context
q *QueryCondition q QueryCondition
} }
tests := []struct { tests := []struct {
name string name string
@ -407,7 +407,7 @@ func Test_findToStringMap(t *testing.T) {
type args struct { type args struct {
db dbQuery db dbQuery
ctx context.Context ctx context.Context
q *QueryCondition q QueryCondition
} }
tests := []struct { tests := []struct {
name string name string
@ -482,7 +482,7 @@ func Test_findScanner(t *testing.T) {
db dbQuery db dbQuery
ctx context.Context ctx context.Context
fn func(T) fn func(T)
q *QueryCondition q QueryCondition
} }
type testCase[T Model] struct { type testCase[T Model] struct {
name string name string
@ -545,7 +545,7 @@ func Test_gets(t *testing.T) {
type args struct { type args struct {
db dbQuery db dbQuery
ctx context.Context ctx context.Context
q *QueryCondition q QueryCondition
} }
type testCase[T Model] struct { type testCase[T Model] struct {
name string name string
@ -583,7 +583,7 @@ func Test_finds(t *testing.T) {
type args struct { type args struct {
db dbQuery db dbQuery
ctx context.Context ctx context.Context
q *QueryCondition q QueryCondition
} }
type testCase[T Model] struct { type testCase[T Model] struct {
name string name string