diff --git a/model/condition.go b/model/condition.go new file mode 100644 index 0000000..da4ef85 --- /dev/null +++ b/model/condition.go @@ -0,0 +1,86 @@ +package model + +type QueryCondition struct { + where ParseWhere + fields string + group string + order SqlBuilder + join SqlBuilder + having SqlBuilder + page int + limit int + offset int + in [][]any +} + +func Conditions(fns ...Condition) *QueryCondition { + r := &QueryCondition{} + for _, fn := range fns { + fn(r) + } + if r.fields == "" { + r.fields = "*" + } + return r +} + +type Condition func(c *QueryCondition) + +func Where(where ParseWhere) Condition { + return func(c *QueryCondition) { + c.where = where + } +} +func Fields(fields string) Condition { + return func(c *QueryCondition) { + c.fields = fields + } +} + +func Group(group string) Condition { + return func(c *QueryCondition) { + c.group = group + } +} + +func Order(order SqlBuilder) Condition { + return func(c *QueryCondition) { + c.order = order + } +} + +func Join(join SqlBuilder) Condition { + return func(c *QueryCondition) { + c.join = join + } +} + +func Having(having SqlBuilder) Condition { + return func(c *QueryCondition) { + c.having = having + } +} + +func Page(page int) Condition { + return func(c *QueryCondition) { + c.page = page + } +} + +func Limit(limit int) Condition { + return func(c *QueryCondition) { + c.limit = limit + } +} + +func Offset(offset int) Condition { + return func(c *QueryCondition) { + c.offset = offset + } +} + +func In(in ...[]any) Condition { + return func(c *QueryCondition) { + c.in = append(c.in, in...) + } +} diff --git a/model/query.go b/model/query.go index f994166..596faa4 100644 --- a/model/query.go +++ b/model/query.go @@ -79,8 +79,8 @@ func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, gr return } tp := "select %s from %s %s %s %s %s %s limit %d,%d" - sql := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) - err = globalBb.Select(ctx, &r, sql, args...) + sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) + err = globalBb.Select(ctx, &r, sq, args...) if err != nil { return } @@ -89,8 +89,8 @@ func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, gr func FindOneById[T Model, I number.IntNumber](ctx context.Context, id I) (T, error) { var r T - sql := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) - err := globalBb.Get(ctx, &r, sql, id) + sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) + err := globalBb.Get(ctx, &r, sq, id) if err != nil { return r, err } @@ -109,8 +109,8 @@ func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, ord } } tp := "select %s from %s %s %s" - sql := fmt.Sprintf(tp, fields, r.Table(), w, order.parseOrderBy()) - err = globalBb.Get(ctx, &r, sql, args...) + sq := fmt.Sprintf(tp, fields, r.Table(), w, order.parseOrderBy()) + err = globalBb.Get(ctx, &r, sq, args...) if err != nil { return r, err } @@ -129,8 +129,8 @@ func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in . } } tp := "select %s from %s %s order by %s desc limit 1" - sql := fmt.Sprintf(tp, fields, r.Table(), w, r.PrimaryKey()) - err = globalBb.Get(ctx, &r, sql, args...) + sq := fmt.Sprintf(tp, fields, r.Table(), w, r.PrimaryKey()) + err = globalBb.Get(ctx, &r, sq, args...) if err != nil { return r, err } @@ -150,8 +150,8 @@ func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, i } } tp := "select %s from %s %s" - sql := fmt.Sprintf(tp, fields, rr.Table(), w) - err = globalBb.Select(ctx, &r, sql, args...) + sq := fmt.Sprintf(tp, fields, rr.Table(), w) + err = globalBb.Select(ctx, &r, sq, args...) if err != nil { return r, err } @@ -202,8 +202,8 @@ func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, if limit > 0 { l = fmt.Sprintf(" limit %d", limit) } - sql := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), l) - err = globalBb.Select(ctx, &r, sql, args...) + sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), l) + err = globalBb.Select(ctx, &r, sq, args...) return } diff --git a/model/query_test.go b/model/query_test.go index 0674c42..8e335a1 100644 --- a/model/query_test.go +++ b/model/query_test.go @@ -3,11 +3,15 @@ package model import ( "context" "database/sql" + "fmt" "github.com/fthvgb1/wp-go/helper/number" "github.com/fthvgb1/wp-go/helper/slice" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" + "log" "reflect" + "strconv" + "strings" "testing" "time" ) @@ -104,14 +108,34 @@ type SqlxDb struct { var Db *SqlxDb -func (r SqlxDb) Select(ctx context.Context, dest any, sql string, params ...any) error { +func (r SqlxDb) Select(_ context.Context, dest any, sql string, params ...any) error { + log.Println(formatSql(sql, params)) return r.sqlx.Select(dest, sql, params...) } -func (r SqlxDb) Get(ctx context.Context, dest any, sql string, params ...any) error { +func (r SqlxDb) Get(_ context.Context, dest any, sql string, params ...any) error { + log.Println(formatSql(sql, params)) return r.sqlx.Get(dest, sql, params...) } +func formatSql(sql string, params []any) string { + for _, param := range params { + switch param.(type) { + case string: + sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1) + case int64: + sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1) + case int: + sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1) + case uint64: + sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1) + case float64: + sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1) + } + } + return sql +} + var ctx = context.Background() func init() { diff --git a/model/querycondition.go b/model/querycondition.go new file mode 100644 index 0000000..1976754 --- /dev/null +++ b/model/querycondition.go @@ -0,0 +1,134 @@ +package model + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +// Finds can use offset +// +// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 +func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { + var rr T + w := "" + var args []any + if q.where != nil { + w, args, err = q.where.ParseWhere(&q.in) + if err != nil { + return r, err + } + } + h := "" + if q.having != nil { + hh, arg, err := q.having.ParseWhere(&q.in) + if err != nil { + return r, err + } + args = append(args, arg...) + h = strings.Replace(hh, " where", " having", 1) + } + + j := q.join.parseJoin() + groupBy := "" + if q.group != "" { + g := strings.Builder{} + g.WriteString(" group by ") + g.WriteString(q.group) + groupBy = g.String() + } + tp := "select %s from %s %s %s %s %s %s %s" + l := "" + + if q.limit > 0 { + l = fmt.Sprintf(" limit %d", q.limit) + } + if q.offset > 0 { + l = fmt.Sprintf(" %s offset %d", l, q.offset) + } + sq := fmt.Sprintf(tp, q.fields, rr.Table(), j, w, groupBy, h, q.order.parseOrderBy(), l) + err = globalBb.Select(ctx, &r, sq, args...) + return +} + +// ChunkFind 分片查询并直接返回所有结果 +// +// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 +func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { + i := 1 + var rr []T + var total int + var offset int + for { + if 1 == i { + rr, total, err = SimplePagination[T](ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) + } else { + rr, err = Finds[T](ctx, Conditions( + Where(q.where), + Fields(q.fields), + Group(q.group), + Having(q.having), + Join(q.join), + Order(q.order), + Offset(offset), + Limit(perLimit), + In(q.in...), + )) + } + offset += perLimit + if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { + return + } + r = append(r, rr...) + if len(r) >= total { + break + } + i++ + } + return +} + +// Chunk 分片查询并函数过虑返回新类型的切片 +// +// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 +func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { + i := 1 + var rr []T + var count int + var total int + var offset int + for { + if 1 == i { + rr, total, err = SimplePagination[T](ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) + } else { + rr, err = Finds[T](ctx, Conditions( + Where(q.where), + Fields(q.fields), + Group(q.group), + Having(q.having), + Join(q.join), + Order(q.order), + Offset(offset), + Limit(perLimit), + In(q.in...), + )) + } + offset += perLimit + if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { + return + } + for _, t := range rr { + v, ok := fn(t) + if ok { + r = append(r, v) + } + } + count += len(rr) + if count >= total { + break + } + i++ + } + return +} diff --git a/model/querycondition_test.go b/model/querycondition_test.go new file mode 100644 index 0000000..2cabfb0 --- /dev/null +++ b/model/querycondition_test.go @@ -0,0 +1,175 @@ +package model + +import ( + "context" + "github.com/fthvgb1/wp-go/helper/number" + "github.com/fthvgb1/wp-go/helper/slice" + "reflect" + "strconv" + "strings" + "testing" +) + +func TestFinds(t *testing.T) { + type args struct { + ctx context.Context + q *QueryCondition + } + type testCase[T Model] struct { + name string + args args + wantR []T + wantErr bool + } + tests := []testCase[post]{ + { + name: "t1", + args: args{ + ctx: context.Background(), + q: Conditions( + Where(SqlBuilder{ + {"post_status", "publish"}, {"ID", "in", ""}}, + ), + Order(SqlBuilder{{"ID", "desc"}}), + Offset(10), + Limit(10), + In([][]any{slice.ToAnySlice(number.Range(1, 1000, 1))}...), + ), + }, + wantR: func() []post { + r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, 1000, 1), strconv.Itoa), ",")+") order by ID desc limit 10 offset 10 ") + if err != nil { + panic(err) + } + return r + }(), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotR, err := Finds[post](tt.args.ctx, tt.args.q) + if (err != nil) != tt.wantErr { + t.Errorf("Findx() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotR, tt.wantR) { + t.Errorf("Findx() gotR = %v, want %v", gotR, tt.wantR) + } + }) + } +} + +func TestChunkFind(t *testing.T) { + type args struct { + ctx context.Context + perLimit int + q *QueryCondition + } + type testCase[T Model] struct { + name string + args args + wantR []T + wantErr bool + } + n := 500 + tests := []testCase[post]{ + { + name: "in,orderBy", + args: args{ + ctx: ctx, + q: Conditions( + Where(SqlBuilder{{ + "post_status", "publish", + }, {"ID", "in", ""}}), + Order(SqlBuilder{{"ID", "desc"}}), + In([][]any{slice.ToAnySlice(number.Range(1, n, 1))}...), + ), + perLimit: 20, + }, + wantR: func() []post { + r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, n, 1), strconv.Itoa), ",")+") order by ID desc") + if err != nil { + panic(err) + } + return r + }(), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotR, err := ChunkFind[post](tt.args.ctx, tt.args.perLimit, tt.args.q) + if (err != nil) != tt.wantErr { + t.Errorf("ChunkFind() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotR, tt.wantR) { + t.Errorf("ChunkFind() gotR = %v, want %v", gotR, tt.wantR) + } + }) + } +} + +func TestChunk(t *testing.T) { + type args[T Model, R any] struct { + ctx context.Context + perLimit int + fn func(rows T) (R, bool) + q *QueryCondition + } + type testCase[T Model, R any] struct { + name string + args args[T, R] + wantR []R + wantErr bool + } + n := 500 + tests := []testCase[post, uint64]{ + { + name: "t1", + args: args[post, uint64]{ + ctx: ctx, + perLimit: 20, + fn: func(t post) (uint64, bool) { + if t.Id > 300 { + return t.Id, true + } + return 0, false + }, + q: Conditions( + Where(SqlBuilder{{ + "post_status", "publish", + }, {"ID", "in", ""}}), + Order(SqlBuilder{{"ID", "desc"}}), + In([][]any{slice.ToAnySlice(number.Range(1, n, 1))}...), + ), + }, + wantR: func() []uint64 { + r, err := Select[post](ctx, "select * from "+post{}.Table()+" where post_status='publish' and ID in ("+strings.Join(slice.Map(number.Range(1, n, 1), strconv.Itoa), ",")+") order by ID desc") + if err != nil { + panic(err) + } + return slice.FilterAndMap(r, func(t post) (uint64, bool) { + if t.Id <= 300 { + return 0, false + } + return t.Id, true + }) + }(), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotR, err := Chunk[post](tt.args.ctx, tt.args.perLimit, tt.args.fn, tt.args.q) + if (err != nil) != tt.wantErr { + t.Errorf("Chunk() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotR, tt.wantR) { + t.Errorf("Chunk() gotR = %v, want %v", gotR, tt.wantR) + } + }) + } +}