hasone hasmany 改

This commit is contained in:
xing 2023-05-18 22:27:28 +08:00
parent 044e55a399
commit 6caf07b575
8 changed files with 197 additions and 129 deletions

View File

@ -94,7 +94,7 @@ func GetPostsByIds(a ...any) (m map[uint64]models.Posts, err error) {
func SearchPostIds(args ...any) (ids PostIds, err error) {
ctx := args[0].(context.Context)
q := args[1].(model.QueryCondition)
q := args[1].(*model.QueryCondition)
page := args[2].(int)
pageSize := args[3].(int)
q.Fields = "ID"

View File

@ -79,7 +79,7 @@ func (i *IndexHandle) ParseIndex(parm *IndexParams) (err error) {
func (i *IndexHandle) GetIndexData() (posts []models.Posts, totalRaw int, err error) {
q := model.QueryCondition{
q := &model.QueryCondition{
Where: i.Param.Where,
Order: model.SqlBuilder{{i.Param.OrderBy, i.Param.Order}},
Join: i.Param.Join,

View File

@ -12,28 +12,19 @@ type QueryCondition struct {
Offset int
In [][]any
Relation map[string]*QueryCondition
WithJoin bool
}
func Conditions(fns ...Condition) QueryCondition {
r := QueryCondition{}
func Conditions(fns ...Condition) *QueryCondition {
r := &QueryCondition{}
for _, fn := range fns {
fn(&r)
fn(r)
}
if r.Fields == "" {
r.Fields = "*"
}
return r
}
func WithConditions(fns ...Condition) *QueryCondition {
r := QueryCondition{}
for _, fn := range fns {
fn(&r)
}
if r.Fields == "" {
r.Fields = "*"
}
return &r
}
type Condition func(c *QueryCondition)
@ -104,3 +95,9 @@ func With(tableTag string, q *QueryCondition) Condition {
c.Relation[tableTag] = q
}
}
func WithJoin(isJoin bool) Condition {
return func(c *QueryCondition) {
c.WithJoin = isJoin
}
}

View File

@ -23,7 +23,7 @@ func (c count[T]) Table() string {
return c.t.Table()
}
func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page, pageSize int) (r []T, total int, err error) {
func pagination[T Model](db dbQuery, ctx context.Context, q *QueryCondition, page, pageSize int) (r []T, total int, err error) {
if page < 1 || pageSize < 1 {
return
}
@ -42,7 +42,7 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page
if qx.From == "" {
qx.From = Table[T]()
}
sq, in, er := BuildQuerySql(qx)
sq, in, er := BuildQuerySql(&qx)
qx.In = [][]any{in}
if er != nil {
err = er
@ -55,7 +55,7 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page
Fields: "count(*) n",
}
}
n, err := gets[count[T]](db, ctx, qx)
n, err := gets[count[T]](db, ctx, &qx)
total = n.N
if err != nil || total < 1 {
return
@ -77,21 +77,21 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page
return
}
func paginationToMap[T Model](db dbQuery, ctx context.Context, q QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) {
func paginationToMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) {
ctx = context.WithValue(ctx, "handle=>toMap", &r)
_, total, err = pagination[T](db, ctx, q, page, pageSize)
return
}
func PaginationToMap[T Model](ctx context.Context, q QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) {
func PaginationToMap[T Model](ctx context.Context, q *QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) {
return paginationToMap[T](globalBb, ctx, q, page, pageSize)
}
func PaginationToMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) {
func PaginationToMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) {
return paginationToMap[T](db, ctx, q, page, pageSize)
}
func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) {
return gets[T](globalBb, ctx, QueryCondition{
return gets[T](globalBb, ctx, &QueryCondition{
Fields: "*",
Where: SqlBuilder{
{PrimaryKey[T](), "=", number.IntToString(id), "int"},
@ -119,7 +119,7 @@ func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in .
}
func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) {
s, args, err := BuildQuerySql(QueryCondition{
s, args, err := BuildQuerySql(&QueryCondition{
Where: where,
Fields: fields,
In: in,
@ -144,7 +144,7 @@ func Select[T Model](ctx context.Context, sql string, params ...any) ([]T, error
}
func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) {
q := QueryCondition{
q := &QueryCondition{
Where: where,
Fields: fields,
Group: group,

View File

@ -41,6 +41,12 @@ type post struct {
PostMeta *[]models.PostMeta `table:"wp_postmeta meta" foreignKey:"post_id" local:"ID" relation:"hasMany"`
}
type TermRelationships struct {
ObjectID uint64 `db:"object_id"`
TermTaxonomyId uint64 `db:"term_taxonomy_id"`
TermOrder int64 `db:"term_order"`
}
type user struct {
Id uint64 `gorm:"column:ID" db:"ID" json:"ID"`
UserLogin string `gorm:"column:user_login" db:"user_login" json:"user_login"`
@ -329,6 +335,30 @@ func TestFindOneById(t *testing.T) {
}
}
func TestGets2(t *testing.T) {
t.Run("hasOne", func(t *testing.T) {
{
q := Conditions(
Where(SqlBuilder{{"id = 190"}}),
With("user", Conditions(
Fields("ID,user_login,user_pass"),
)),
Fields("posts.*"),
From("wp_posts posts"),
With("meta", Conditions(
WithJoin(true),
)),
)
ctx = context.WithValue(ctx, "ancestorsQueryCondition", q)
got, err := Gets[post](ctx, q)
_ = got
if err != nil {
t.Errorf("err:%v", err)
}
}
})
}
func TestFirstOne(t *testing.T) {
type args struct {
where ParseWhere
@ -365,18 +395,6 @@ func TestFirstOne(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := FirstOne[post](ctx, tt.args.where, tt.args.fields, tt.args.order, tt.args.in...)
gott, err := Gets[post](ctx, Conditions(
Where(SqlBuilder{{"post_status", "publish"}}),
Order([][]string{{"ID", "desc"}}),
With("user", WithConditions(
Fields("ID,user_login,user_pass"),
Where(SqlBuilder{
{"user.ID", ">", "0", "int"},
}),
)),
With("meta", nil),
))
_ = gott
if (err != nil) != tt.wantErr {
t.Errorf("FirstOne() error = %v, wantErr %v", err, tt.wantErr)
return
@ -483,7 +501,7 @@ func Test_pagination(t *testing.T) {
type args struct {
db dbQuery
ctx context.Context
q QueryCondition
q *QueryCondition
page int
pageSize int
}
@ -500,7 +518,7 @@ func Test_pagination(t *testing.T) {
args: args{
db: glob,
ctx: ctx,
q: QueryCondition{
q: &QueryCondition{
Fields: "post_type,count(*) ID",
Group: "post_type",
Having: SqlBuilder{{"ID", ">", "1", "int"}},
@ -541,7 +559,7 @@ func Test_paginationToMap(t *testing.T) {
type args struct {
db dbQuery
ctx context.Context
q QueryCondition
q *QueryCondition
page int
pageSize int
}
@ -557,7 +575,7 @@ func Test_paginationToMap(t *testing.T) {
args: args{
db: glob,
ctx: ctx,
q: QueryCondition{
q: &QueryCondition{
Fields: "ID",
Where: SqlBuilder{{"ID < 200"}},
},
@ -572,7 +590,7 @@ func Test_paginationToMap(t *testing.T) {
args: args{
db: glob,
ctx: ctx,
q: QueryCondition{
q: &QueryCondition{
Fields: "ID",
Where: SqlBuilder{{"ID < 200"}},
},

View File

@ -12,7 +12,7 @@ import (
// Finds 比 Find 多一个offset
//
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
func Finds[T Model](ctx context.Context, q QueryCondition) (r []T, err error) {
func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) {
r, err = finds[T](globalBb, ctx, q)
return
}
@ -20,13 +20,13 @@ func Finds[T Model](ctx context.Context, q QueryCondition) (r []T, err error) {
// FindFromDB 同 Finds 使用指定 db 查询
//
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
func FindFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) {
func FindFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) {
r, err = finds[T](db, ctx, q)
return
}
func finds[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) {
setTable[T](&q)
func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) {
setTable[T](q)
sq, args, err := BuildQuerySql(q)
if err != nil {
return
@ -36,7 +36,7 @@ func finds[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, e
return
}
func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) {
func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
i := 1
var rr []T
var total int
@ -65,7 +65,7 @@ func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCo
// ChunkFind 分片查询并直接返回所有结果
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFind[T Model](ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) {
func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
r, err = chunkFind[T](globalBb, ctx, perLimit, q)
return
}
@ -73,7 +73,7 @@ func ChunkFind[T Model](ctx context.Context, perLimit int, q QueryCondition) (r
// ChunkFindFromDB 同 ChunkFind
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) {
func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
r, err = chunkFind[T](db, ctx, perLimit, q)
return
}
@ -81,7 +81,7 @@ func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q Q
// Chunk 分片查询并函数过虑返回新类型的切片
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) {
func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
r, err = chunk(globalBb, ctx, perLimit, fn, q)
return
}
@ -89,12 +89,12 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R
// ChunkFromDB 同 Chunk
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) {
func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
r, err = chunk(db, ctx, perLimit, fn, q)
return
}
func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) {
func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
i := 1
var rr []T
var count int
@ -130,25 +130,25 @@ func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn fun
// Pagination 同
//
// Condition 中可使用 Where Fields From Group Having Join Order Limit In 函数
func Pagination[T Model](ctx context.Context, q QueryCondition, page, pageSize int) ([]T, int, error) {
func Pagination[T Model](ctx context.Context, q *QueryCondition, page, pageSize int) ([]T, int, error) {
return pagination[T](globalBb, ctx, q, page, pageSize)
}
// PaginationFromDB 同 Pagination 方便多个db使用
//
// Condition 中可使用 Where Fields Group Having Join Order Limit In 函数
func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition, page, pageSize int) ([]T, int, error) {
func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition, page, pageSize int) ([]T, int, error) {
return pagination[T](db, ctx, q, page, pageSize)
}
func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q QueryCondition) ([]T, error) {
func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q *QueryCondition) ([]T, error) {
return column[V, T](globalBb, ctx, fn, q)
}
func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q QueryCondition) (r []T, err error) {
func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) {
return column[V, T](db, ctx, fn, q)
}
func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q QueryCondition) (r []T, err error) {
func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) {
res, err := finds[V](db, ctx, q)
if err != nil {
return nil, err
@ -157,11 +157,11 @@ func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool
return
}
func GetField[T Model](ctx context.Context, field string, q QueryCondition) (r string, err error) {
func GetField[T Model](ctx context.Context, field string, q *QueryCondition) (r string, err error) {
r, err = getField[T](globalBb, ctx, field, q)
return
}
func getField[T Model](db dbQuery, ctx context.Context, field string, q QueryCondition) (r string, err error) {
func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) {
if q.Fields == "" || q.Fields == "*" {
q.Fields = field
}
@ -176,12 +176,12 @@ func getField[T Model](db dbQuery, ctx context.Context, field string, q QueryCon
}
return
}
func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q QueryCondition) (r string, err error) {
func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) {
return getField[T](db, ctx, field, q)
}
func getToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) {
setTable[T](&q)
func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) {
setTable[T](q)
rawSql, in, err := BuildQuerySql(q)
if err != nil {
return nil, err
@ -190,13 +190,13 @@ func getToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition)
err = db.Get(ctx, &r, rawSql, in...)
return
}
func GetToStringMap[T Model](ctx context.Context, q QueryCondition) (r map[string]string, err error) {
func GetToStringMap[T Model](ctx context.Context, q *QueryCondition) (r map[string]string, err error) {
r, err = getToStringMap[T](globalBb, ctx, q)
return
}
func findToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) {
setTable[T](&q)
func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) {
setTable[T](q)
rawSql, in, err := BuildQuerySql(q)
if err != nil {
return nil, err
@ -206,30 +206,30 @@ func findToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition)
return
}
func FindToStringMap[T Model](ctx context.Context, q QueryCondition) (r []map[string]string, err error) {
func FindToStringMap[T Model](ctx context.Context, q *QueryCondition) (r []map[string]string, err error) {
r, err = findToStringMap[T](globalBb, ctx, q)
return
}
func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) {
func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) {
r, err = findToStringMap[T](db, ctx, q)
return
}
func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) {
func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) {
r, err = getToStringMap[T](db, ctx, q)
return
}
func BuildQuerySql(q QueryCondition) (r string, args []any, err error) {
w := ""
func BuildQuerySql(q *QueryCondition) (r string, args []any, err error) {
where := ""
if q.Where != nil {
w, args, err = q.Where.ParseWhere(&q.In)
where, args, err = q.Where.ParseWhere(&q.In)
if err != nil {
return
}
}
h := ""
having := ""
if q.Having != nil {
hh, arg, er := q.Having.ParseWhere(&q.In)
if er != nil {
@ -237,15 +237,17 @@ func BuildQuerySql(q QueryCondition) (r string, args []any, err error) {
return
}
args = append(args, arg...)
h = strings.Replace(hh, " where", " having", 1)
having = strings.Replace(hh, " where", " having", 1)
}
if len(args) == 0 && len(q.In) > 0 {
for _, antes := range q.In {
args = append(args, antes...)
}
}
j := q.Join.parseJoin()
join := ""
if q.Join != nil {
join = q.Join.parseJoin()
}
groupBy := ""
if q.Group != "" {
g := strings.Builder{}
@ -262,12 +264,16 @@ func BuildQuerySql(q QueryCondition) (r string, args []any, err error) {
if q.Offset > 0 {
l = fmt.Sprintf(" %s offset %d", l, q.Offset)
}
r = fmt.Sprintf(tp, q.Fields, table, j, w, groupBy, h, q.Order.parseOrderBy(), l)
order := ""
if q.Order != nil {
order = q.Order.parseOrderBy()
}
r = fmt.Sprintf(tp, q.Fields, table, join, where, groupBy, having, order, l)
return
}
func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) (err error) {
setTable[T](&q)
func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) (err error) {
setTable[T](q)
s, args, err := BuildQuerySql(q)
if err != nil {
return
@ -281,33 +287,59 @@ func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCo
return
}
func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) error {
func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) error {
return findScanner[T](db, ctx, fn, q)
}
func FindScanner[T Model](ctx context.Context, fn func(T), q QueryCondition) error {
func FindScanner[T Model](ctx context.Context, fn func(T), q *QueryCondition) error {
return findScanner[T](globalBb, ctx, fn, q)
}
func Gets[T Model](ctx context.Context, q QueryCondition) (T, error) {
func Gets[T Model](ctx context.Context, q *QueryCondition) (T, error) {
return gets[T](globalBb, ctx, q)
}
func GetsFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (T, error) {
func GetsFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (T, error) {
return gets[T](db, ctx, q)
}
func gets[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r T, err error) {
setTable[T](&q)
func gets[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r T, err error) {
setTable[T](q)
if len(q.Relation) < 1 {
s, args, er := BuildQuerySql(q)
if er != nil {
err = er
return
}
err = db.Get(ctx, &r, s, args...)
return
}
err = parseRelation(false, db, ctx, &r, q)
return
}
func parseRelation(isMultiple bool, db dbQuery, ctx context.Context, r any, q *QueryCondition) (err error) {
fn, fns := Relation(db, ctx, r, q)
for _, f := range fn {
f()
}
s, args, err := BuildQuerySql(q)
if err != nil {
return
}
err = db.Get(ctx, &r, s, args...)
if isMultiple {
err = db.Select(ctx, r, s, args...)
} else {
err = db.Get(ctx, r, s, args...)
}
if err != nil {
return
}
for _, f := range fns {
err = f()
if err != nil {
return
}
if len(q.Relation) > 0 {
err = Relation[T](db, ctx, &r, &q)
}
return
}

View File

@ -15,7 +15,7 @@ import (
func TestFinds(t *testing.T) {
type args struct {
ctx context.Context
q QueryCondition
q *QueryCondition
}
type testCase[T Model] struct {
name string
@ -66,7 +66,7 @@ func TestChunkFind(t *testing.T) {
type args struct {
ctx context.Context
perLimit int
q QueryCondition
q *QueryCondition
}
type testCase[T Model] struct {
name string
@ -118,7 +118,7 @@ func TestChunk(t *testing.T) {
ctx context.Context
perLimit int
fn func(rows T) (R, bool)
q QueryCondition
q *QueryCondition
}
type testCase[T Model, R any] struct {
name string
@ -179,7 +179,7 @@ func TestChunk(t *testing.T) {
func TestPagination(t *testing.T) {
type args struct {
ctx context.Context
q QueryCondition
q *QueryCondition
page int
pageSize int
}
@ -238,7 +238,7 @@ func TestColumn(t *testing.T) {
type args[V Model, T any] struct {
ctx context.Context
fn func(V) (T, bool)
q QueryCondition
q *QueryCondition
}
type testCase[V Model, T any] struct {
name string
@ -354,7 +354,7 @@ func Test_getToStringMap(t *testing.T) {
type args struct {
db dbQuery
ctx context.Context
q QueryCondition
q *QueryCondition
}
tests := []struct {
name string
@ -409,7 +409,7 @@ func Test_findToStringMap(t *testing.T) {
type args struct {
db dbQuery
ctx context.Context
q QueryCondition
q *QueryCondition
}
tests := []struct {
name string
@ -484,7 +484,7 @@ func Test_findScanner(t *testing.T) {
db dbQuery
ctx context.Context
fn func(T)
q QueryCondition
q *QueryCondition
}
type testCase[T Model] struct {
name string
@ -547,7 +547,7 @@ func Test_gets(t *testing.T) {
type args struct {
db dbQuery
ctx context.Context
q QueryCondition
q *QueryCondition
}
type testCase[T Model] struct {
name string
@ -585,7 +585,7 @@ func Test_finds(t *testing.T) {
type args struct {
db dbQuery
ctx context.Context
q QueryCondition
q *QueryCondition
}
type testCase[T Model] struct {
name string
@ -648,7 +648,7 @@ func Test_finds(t *testing.T) {
func TestGets(t *testing.T) {
type args struct {
ctx context.Context
q QueryCondition
q *QueryCondition
}
type testCase[T Model] struct {
name string

View File

@ -3,6 +3,7 @@ package model
import (
"context"
"fmt"
"github.com/fthvgb1/wp-go/helper"
"reflect"
"strings"
)
@ -13,15 +14,19 @@ func setTable[T Model](q *QueryCondition) {
}
}
func Relation[T Model](db dbQuery, ctx context.Context, r *T, q *QueryCondition) (err error) {
var rr T
t := reflect.TypeOf(rr)
func Relation(db dbQuery, ctx context.Context, r any, q *QueryCondition) ([]func(), []func() error) {
var fn []func()
var fns []func() error
t := reflect.TypeOf(r).Elem()
v := reflect.ValueOf(r).Elem()
for tableTag, relation := range q.Relation {
if tableTag == "" {
continue
}
tableTag := tableTag
relation := relation
for i := 0; i < t.NumField(); i++ {
i := i
tag := t.Field(i).Tag
table, ok := tag.Lookup("table")
if !ok || table == "" {
@ -31,6 +36,14 @@ func Relation[T Model](db dbQuery, ctx context.Context, r *T, q *QueryCondition)
if tables[len(tables)-1] != tableTag {
continue
}
foreignKey := tag.Get("foreignKey")
if foreignKey == "" {
continue
}
localKey := tag.Get("local")
if localKey == "" {
continue
}
if relation == nil {
relation = &QueryCondition{
Fields: "*",
@ -42,10 +55,19 @@ func Relation[T Model](db dbQuery, ctx context.Context, r *T, q *QueryCondition)
for ; j < t.NumField(); j++ {
vvv, ok := t.Field(j).Tag.Lookup("db")
if ok && vvv == tag.Get("local") {
id = fmt.Sprintf("%v", v.Field(j).Interface())
break
}
}
if relation.WithJoin {
from := strings.Split(q.From, " ")
fn = append(fn, func() {
qq := helper.GetContextVal(ctx, "ancestorsQueryCondition", q)
qq.Join = append(q.Join, SqlBuilder{
{"left join", table, fmt.Sprintf("%s.%s=%s.%s", tables[len(tables)-1], foreignKey, from[len(from)-1], localKey)},
}...)
})
}
fns = append(fns, func() error {
{
var w any = relation.Where
if w == nil {
@ -53,29 +75,28 @@ func Relation[T Model](db dbQuery, ctx context.Context, r *T, q *QueryCondition)
}
ww, ok := w.(SqlBuilder)
if ok {
id = fmt.Sprintf("%v", v.Field(j).Interface())
ww = append(ww, SqlBuilder{{
tag.Get("foreignKey"), "=", id, "int",
foreignKey, "=", id, "int",
}}...)
relation.Where = ww
}
}
sq, args, er := BuildQuerySql(*relation)
if er != nil {
err = er
return
}
var err error
vv := reflect.New(v.Field(i).Type().Elem()).Interface()
switch tag.Get("relation") {
case "hasOne":
err = db.Get(ctx, vv, sq, args...)
err = parseRelation(false, db, ctx, vv, relation)
case "hasMany":
err = db.Select(ctx, vv, sq, args...)
err = parseRelation(true, db, ctx, vv, relation)
}
if err != nil {
return
return err
}
v.Field(i).Set(reflect.ValueOf(vv))
return nil
})
}
}
return
return fn, fns
}