diff --git a/app/pkg/dao/posts.go b/app/pkg/dao/posts.go index 6eaaa03..4016e53 100644 --- a/app/pkg/dao/posts.go +++ b/app/pkg/dao/posts.go @@ -94,7 +94,7 @@ func GetPostsByIds(a ...any) (m map[uint64]models.Posts, err error) { func SearchPostIds(args ...any) (ids PostIds, err error) { ctx := args[0].(context.Context) - q := args[1].(model.QueryCondition) + q := args[1].(*model.QueryCondition) page := args[2].(int) pageSize := args[3].(int) q.Fields = "ID" diff --git a/app/theme/wp/index.go b/app/theme/wp/index.go index cea2210..5ad9db1 100644 --- a/app/theme/wp/index.go +++ b/app/theme/wp/index.go @@ -79,7 +79,7 @@ func (i *IndexHandle) ParseIndex(parm *IndexParams) (err error) { func (i *IndexHandle) GetIndexData() (posts []models.Posts, totalRaw int, err error) { - q := model.QueryCondition{ + q := &model.QueryCondition{ Where: i.Param.Where, Order: model.SqlBuilder{{i.Param.OrderBy, i.Param.Order}}, Join: i.Param.Join, diff --git a/model/condition.go b/model/condition.go index efa0cb6..5b8cf7a 100644 --- a/model/condition.go +++ b/model/condition.go @@ -12,28 +12,19 @@ type QueryCondition struct { Offset int In [][]any Relation map[string]*QueryCondition + WithJoin bool } -func Conditions(fns ...Condition) QueryCondition { - r := QueryCondition{} +func Conditions(fns ...Condition) *QueryCondition { + r := &QueryCondition{} for _, fn := range fns { - fn(&r) + fn(r) } if r.Fields == "" { r.Fields = "*" } return r } -func WithConditions(fns ...Condition) *QueryCondition { - r := QueryCondition{} - for _, fn := range fns { - fn(&r) - } - if r.Fields == "" { - r.Fields = "*" - } - return &r -} type Condition func(c *QueryCondition) @@ -104,3 +95,9 @@ func With(tableTag string, q *QueryCondition) Condition { c.Relation[tableTag] = q } } + +func WithJoin(isJoin bool) Condition { + return func(c *QueryCondition) { + c.WithJoin = isJoin + } +} diff --git a/model/query.go b/model/query.go index 1cffd9e..6689bc0 100644 --- a/model/query.go +++ b/model/query.go @@ -23,7 +23,7 @@ func (c count[T]) Table() string { return c.t.Table() } -func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page, pageSize int) (r []T, total int, err error) { +func pagination[T Model](db dbQuery, ctx context.Context, q *QueryCondition, page, pageSize int) (r []T, total int, err error) { if page < 1 || pageSize < 1 { return } @@ -42,7 +42,7 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page if qx.From == "" { qx.From = Table[T]() } - sq, in, er := BuildQuerySql(qx) + sq, in, er := BuildQuerySql(&qx) qx.In = [][]any{in} if er != nil { err = er @@ -55,7 +55,7 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page Fields: "count(*) n", } } - n, err := gets[count[T]](db, ctx, qx) + n, err := gets[count[T]](db, ctx, &qx) total = n.N if err != nil || total < 1 { return @@ -77,21 +77,21 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page return } -func paginationToMap[T Model](db dbQuery, ctx context.Context, q QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) { +func paginationToMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) { ctx = context.WithValue(ctx, "handle=>toMap", &r) _, total, err = pagination[T](db, ctx, q, page, pageSize) return } -func PaginationToMap[T Model](ctx context.Context, q QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) { +func PaginationToMap[T Model](ctx context.Context, q *QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) { return paginationToMap[T](globalBb, ctx, q, page, pageSize) } -func PaginationToMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) { +func PaginationToMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition, page, pageSize int) (r []map[string]string, total int, err error) { return paginationToMap[T](db, ctx, q, page, pageSize) } func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) { - return gets[T](globalBb, ctx, QueryCondition{ + return gets[T](globalBb, ctx, &QueryCondition{ Fields: "*", Where: SqlBuilder{ {PrimaryKey[T](), "=", number.IntToString(id), "int"}, @@ -119,7 +119,7 @@ func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in . } func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) { - s, args, err := BuildQuerySql(QueryCondition{ + s, args, err := BuildQuerySql(&QueryCondition{ Where: where, Fields: fields, In: in, @@ -144,7 +144,7 @@ func Select[T Model](ctx context.Context, sql string, params ...any) ([]T, error } func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) { - q := QueryCondition{ + q := &QueryCondition{ Where: where, Fields: fields, Group: group, diff --git a/model/query_test.go b/model/query_test.go index 87a6e13..671e5f8 100644 --- a/model/query_test.go +++ b/model/query_test.go @@ -41,6 +41,12 @@ type post struct { PostMeta *[]models.PostMeta `table:"wp_postmeta meta" foreignKey:"post_id" local:"ID" relation:"hasMany"` } +type TermRelationships struct { + ObjectID uint64 `db:"object_id"` + TermTaxonomyId uint64 `db:"term_taxonomy_id"` + TermOrder int64 `db:"term_order"` +} + type user struct { Id uint64 `gorm:"column:ID" db:"ID" json:"ID"` UserLogin string `gorm:"column:user_login" db:"user_login" json:"user_login"` @@ -329,6 +335,30 @@ func TestFindOneById(t *testing.T) { } } +func TestGets2(t *testing.T) { + t.Run("hasOne", func(t *testing.T) { + { + q := Conditions( + Where(SqlBuilder{{"id = 190"}}), + With("user", Conditions( + Fields("ID,user_login,user_pass"), + )), + Fields("posts.*"), + From("wp_posts posts"), + With("meta", Conditions( + WithJoin(true), + )), + ) + ctx = context.WithValue(ctx, "ancestorsQueryCondition", q) + got, err := Gets[post](ctx, q) + _ = got + if err != nil { + t.Errorf("err:%v", err) + } + } + }) +} + func TestFirstOne(t *testing.T) { type args struct { where ParseWhere @@ -365,18 +395,6 @@ func TestFirstOne(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := FirstOne[post](ctx, tt.args.where, tt.args.fields, tt.args.order, tt.args.in...) - gott, err := Gets[post](ctx, Conditions( - Where(SqlBuilder{{"post_status", "publish"}}), - Order([][]string{{"ID", "desc"}}), - With("user", WithConditions( - Fields("ID,user_login,user_pass"), - Where(SqlBuilder{ - {"user.ID", ">", "0", "int"}, - }), - )), - With("meta", nil), - )) - _ = gott if (err != nil) != tt.wantErr { t.Errorf("FirstOne() error = %v, wantErr %v", err, tt.wantErr) return @@ -483,7 +501,7 @@ func Test_pagination(t *testing.T) { type args struct { db dbQuery ctx context.Context - q QueryCondition + q *QueryCondition page int pageSize int } @@ -500,7 +518,7 @@ func Test_pagination(t *testing.T) { args: args{ db: glob, ctx: ctx, - q: QueryCondition{ + q: &QueryCondition{ Fields: "post_type,count(*) ID", Group: "post_type", Having: SqlBuilder{{"ID", ">", "1", "int"}}, @@ -541,7 +559,7 @@ func Test_paginationToMap(t *testing.T) { type args struct { db dbQuery ctx context.Context - q QueryCondition + q *QueryCondition page int pageSize int } @@ -557,7 +575,7 @@ func Test_paginationToMap(t *testing.T) { args: args{ db: glob, ctx: ctx, - q: QueryCondition{ + q: &QueryCondition{ Fields: "ID", Where: SqlBuilder{{"ID < 200"}}, }, @@ -572,7 +590,7 @@ func Test_paginationToMap(t *testing.T) { args: args{ db: glob, ctx: ctx, - q: QueryCondition{ + q: &QueryCondition{ Fields: "ID", Where: SqlBuilder{{"ID < 200"}}, }, diff --git a/model/querycondition.go b/model/querycondition.go index a589d0e..6fd3615 100644 --- a/model/querycondition.go +++ b/model/querycondition.go @@ -12,7 +12,7 @@ import ( // Finds 比 Find 多一个offset // // Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 -func Finds[T Model](ctx context.Context, q QueryCondition) (r []T, err error) { +func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { r, err = finds[T](globalBb, ctx, q) return } @@ -20,13 +20,13 @@ func Finds[T Model](ctx context.Context, q QueryCondition) (r []T, err error) { // FindFromDB 同 Finds 使用指定 db 查询 // // Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 -func FindFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) { +func FindFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { r, err = finds[T](db, ctx, q) return } -func finds[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) { - setTable[T](&q) +func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { + setTable[T](q) sq, args, err := BuildQuerySql(q) if err != nil { return @@ -36,7 +36,7 @@ func finds[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, e return } -func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) { +func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { i := 1 var rr []T var total int @@ -65,7 +65,7 @@ func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCo // ChunkFind 分片查询并直接返回所有结果 // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFind[T Model](ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) { +func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { r, err = chunkFind[T](globalBb, ctx, perLimit, q) return } @@ -73,7 +73,7 @@ func ChunkFind[T Model](ctx context.Context, perLimit int, q QueryCondition) (r // ChunkFindFromDB 同 ChunkFind // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q QueryCondition) (r []T, err error) { +func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { r, err = chunkFind[T](db, ctx, perLimit, q) return } @@ -81,7 +81,7 @@ func ChunkFindFromDB[T Model](db dbQuery, ctx context.Context, perLimit int, q Q // Chunk 分片查询并函数过虑返回新类型的切片 // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) { +func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { r, err = chunk(globalBb, ctx, perLimit, fn, q) return } @@ -89,12 +89,12 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R // ChunkFromDB 同 Chunk // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) { +func ChunkFromDB[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { r, err = chunk(db, ctx, perLimit, fn, q) return } -func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q QueryCondition) (r []R, err error) { +func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { i := 1 var rr []T var count int @@ -130,25 +130,25 @@ func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn fun // Pagination 同 // // Condition 中可使用 Where Fields From Group Having Join Order Limit In 函数 -func Pagination[T Model](ctx context.Context, q QueryCondition, page, pageSize int) ([]T, int, error) { +func Pagination[T Model](ctx context.Context, q *QueryCondition, page, pageSize int) ([]T, int, error) { return pagination[T](globalBb, ctx, q, page, pageSize) } // PaginationFromDB 同 Pagination 方便多个db使用 // // Condition 中可使用 Where Fields Group Having Join Order Limit In 函数 -func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition, page, pageSize int) ([]T, int, error) { +func PaginationFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition, page, pageSize int) ([]T, int, error) { return pagination[T](db, ctx, q, page, pageSize) } -func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q QueryCondition) ([]T, error) { +func Column[V Model, T any](ctx context.Context, fn func(V) (T, bool), q *QueryCondition) ([]T, error) { return column[V, T](globalBb, ctx, fn, q) } -func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q QueryCondition) (r []T, err error) { +func ColumnFromDB[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { return column[V, T](db, ctx, fn, q) } -func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q QueryCondition) (r []T, err error) { +func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool), q *QueryCondition) (r []T, err error) { res, err := finds[V](db, ctx, q) if err != nil { return nil, err @@ -157,11 +157,11 @@ func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool return } -func GetField[T Model](ctx context.Context, field string, q QueryCondition) (r string, err error) { +func GetField[T Model](ctx context.Context, field string, q *QueryCondition) (r string, err error) { r, err = getField[T](globalBb, ctx, field, q) return } -func getField[T Model](db dbQuery, ctx context.Context, field string, q QueryCondition) (r string, err error) { +func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { if q.Fields == "" || q.Fields == "*" { q.Fields = field } @@ -176,12 +176,12 @@ func getField[T Model](db dbQuery, ctx context.Context, field string, q QueryCon } return } -func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q QueryCondition) (r string, err error) { +func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { return getField[T](db, ctx, field, q) } -func getToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) { - setTable[T](&q) +func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { + setTable[T](q) rawSql, in, err := BuildQuerySql(q) if err != nil { return nil, err @@ -190,13 +190,13 @@ func getToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) err = db.Get(ctx, &r, rawSql, in...) return } -func GetToStringMap[T Model](ctx context.Context, q QueryCondition) (r map[string]string, err error) { +func GetToStringMap[T Model](ctx context.Context, q *QueryCondition) (r map[string]string, err error) { r, err = getToStringMap[T](globalBb, ctx, q) return } -func findToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) { - setTable[T](&q) +func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { + setTable[T](q) rawSql, in, err := BuildQuerySql(q) if err != nil { return nil, err @@ -206,30 +206,30 @@ func findToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) return } -func FindToStringMap[T Model](ctx context.Context, q QueryCondition) (r []map[string]string, err error) { +func FindToStringMap[T Model](ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { r, err = findToStringMap[T](globalBb, ctx, q) return } -func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) { +func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { r, err = findToStringMap[T](db, ctx, q) return } -func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) { +func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { r, err = getToStringMap[T](db, ctx, q) return } -func BuildQuerySql(q QueryCondition) (r string, args []any, err error) { - w := "" +func BuildQuerySql(q *QueryCondition) (r string, args []any, err error) { + where := "" if q.Where != nil { - w, args, err = q.Where.ParseWhere(&q.In) + where, args, err = q.Where.ParseWhere(&q.In) if err != nil { return } } - h := "" + having := "" if q.Having != nil { hh, arg, er := q.Having.ParseWhere(&q.In) if er != nil { @@ -237,15 +237,17 @@ func BuildQuerySql(q QueryCondition) (r string, args []any, err error) { return } args = append(args, arg...) - h = strings.Replace(hh, " where", " having", 1) + having = strings.Replace(hh, " where", " having", 1) } if len(args) == 0 && len(q.In) > 0 { for _, antes := range q.In { args = append(args, antes...) } } - - j := q.Join.parseJoin() + join := "" + if q.Join != nil { + join = q.Join.parseJoin() + } groupBy := "" if q.Group != "" { g := strings.Builder{} @@ -262,12 +264,16 @@ func BuildQuerySql(q QueryCondition) (r string, args []any, err error) { if q.Offset > 0 { l = fmt.Sprintf(" %s offset %d", l, q.Offset) } - r = fmt.Sprintf(tp, q.Fields, table, j, w, groupBy, h, q.Order.parseOrderBy(), l) + order := "" + if q.Order != nil { + order = q.Order.parseOrderBy() + } + r = fmt.Sprintf(tp, q.Fields, table, join, where, groupBy, having, order, l) return } -func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) (err error) { - setTable[T](&q) +func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) (err error) { + setTable[T](q) s, args, err := BuildQuerySql(q) if err != nil { return @@ -281,33 +287,59 @@ func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCo return } -func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) error { +func FindScannerFromDB[T Model](db dbQuery, ctx context.Context, fn func(T), q *QueryCondition) error { return findScanner[T](db, ctx, fn, q) } -func FindScanner[T Model](ctx context.Context, fn func(T), q QueryCondition) error { +func FindScanner[T Model](ctx context.Context, fn func(T), q *QueryCondition) error { return findScanner[T](globalBb, ctx, fn, q) } -func Gets[T Model](ctx context.Context, q QueryCondition) (T, error) { +func Gets[T Model](ctx context.Context, q *QueryCondition) (T, error) { return gets[T](globalBb, ctx, q) } -func GetsFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (T, error) { +func GetsFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (T, error) { return gets[T](db, ctx, q) } -func gets[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r T, err error) { - setTable[T](&q) +func gets[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r T, err error) { + setTable[T](q) + if len(q.Relation) < 1 { + s, args, er := BuildQuerySql(q) + if er != nil { + err = er + return + } + err = db.Get(ctx, &r, s, args...) + return + } + err = parseRelation(false, db, ctx, &r, q) + return +} + +func parseRelation(isMultiple bool, db dbQuery, ctx context.Context, r any, q *QueryCondition) (err error) { + fn, fns := Relation(db, ctx, r, q) + for _, f := range fn { + f() + } s, args, err := BuildQuerySql(q) if err != nil { return } - err = db.Get(ctx, &r, s, args...) + if isMultiple { + err = db.Select(ctx, r, s, args...) + } else { + err = db.Get(ctx, r, s, args...) + } + if err != nil { return } - if len(q.Relation) > 0 { - err = Relation[T](db, ctx, &r, &q) + for _, f := range fns { + err = f() + if err != nil { + return + } } return } diff --git a/model/querycondition_test.go b/model/querycondition_test.go index a6a8b79..faf4e14 100644 --- a/model/querycondition_test.go +++ b/model/querycondition_test.go @@ -15,7 +15,7 @@ import ( func TestFinds(t *testing.T) { type args struct { ctx context.Context - q QueryCondition + q *QueryCondition } type testCase[T Model] struct { name string @@ -66,7 +66,7 @@ func TestChunkFind(t *testing.T) { type args struct { ctx context.Context perLimit int - q QueryCondition + q *QueryCondition } type testCase[T Model] struct { name string @@ -118,7 +118,7 @@ func TestChunk(t *testing.T) { ctx context.Context perLimit int fn func(rows T) (R, bool) - q QueryCondition + q *QueryCondition } type testCase[T Model, R any] struct { name string @@ -179,7 +179,7 @@ func TestChunk(t *testing.T) { func TestPagination(t *testing.T) { type args struct { ctx context.Context - q QueryCondition + q *QueryCondition page int pageSize int } @@ -238,7 +238,7 @@ func TestColumn(t *testing.T) { type args[V Model, T any] struct { ctx context.Context fn func(V) (T, bool) - q QueryCondition + q *QueryCondition } type testCase[V Model, T any] struct { name string @@ -354,7 +354,7 @@ func Test_getToStringMap(t *testing.T) { type args struct { db dbQuery ctx context.Context - q QueryCondition + q *QueryCondition } tests := []struct { name string @@ -409,7 +409,7 @@ func Test_findToStringMap(t *testing.T) { type args struct { db dbQuery ctx context.Context - q QueryCondition + q *QueryCondition } tests := []struct { name string @@ -484,7 +484,7 @@ func Test_findScanner(t *testing.T) { db dbQuery ctx context.Context fn func(T) - q QueryCondition + q *QueryCondition } type testCase[T Model] struct { name string @@ -547,7 +547,7 @@ func Test_gets(t *testing.T) { type args struct { db dbQuery ctx context.Context - q QueryCondition + q *QueryCondition } type testCase[T Model] struct { name string @@ -585,7 +585,7 @@ func Test_finds(t *testing.T) { type args struct { db dbQuery ctx context.Context - q QueryCondition + q *QueryCondition } type testCase[T Model] struct { name string @@ -648,7 +648,7 @@ func Test_finds(t *testing.T) { func TestGets(t *testing.T) { type args struct { ctx context.Context - q QueryCondition + q *QueryCondition } type testCase[T Model] struct { name string diff --git a/model/relation.go b/model/relation.go index 57257e4..f150c77 100644 --- a/model/relation.go +++ b/model/relation.go @@ -3,6 +3,7 @@ package model import ( "context" "fmt" + "github.com/fthvgb1/wp-go/helper" "reflect" "strings" ) @@ -13,15 +14,19 @@ func setTable[T Model](q *QueryCondition) { } } -func Relation[T Model](db dbQuery, ctx context.Context, r *T, q *QueryCondition) (err error) { - var rr T - t := reflect.TypeOf(rr) +func Relation(db dbQuery, ctx context.Context, r any, q *QueryCondition) ([]func(), []func() error) { + var fn []func() + var fns []func() error + t := reflect.TypeOf(r).Elem() v := reflect.ValueOf(r).Elem() for tableTag, relation := range q.Relation { if tableTag == "" { continue } + tableTag := tableTag + relation := relation for i := 0; i < t.NumField(); i++ { + i := i tag := t.Field(i).Tag table, ok := tag.Lookup("table") if !ok || table == "" { @@ -31,6 +36,14 @@ func Relation[T Model](db dbQuery, ctx context.Context, r *T, q *QueryCondition) if tables[len(tables)-1] != tableTag { continue } + foreignKey := tag.Get("foreignKey") + if foreignKey == "" { + continue + } + localKey := tag.Get("local") + if localKey == "" { + continue + } if relation == nil { relation = &QueryCondition{ Fields: "*", @@ -42,40 +55,48 @@ func Relation[T Model](db dbQuery, ctx context.Context, r *T, q *QueryCondition) for ; j < t.NumField(); j++ { vvv, ok := t.Field(j).Tag.Lookup("db") if ok && vvv == tag.Get("local") { - id = fmt.Sprintf("%v", v.Field(j).Interface()) break } } - { - var w any = relation.Where - if w == nil { - w = SqlBuilder{} + if relation.WithJoin { + from := strings.Split(q.From, " ") + fn = append(fn, func() { + qq := helper.GetContextVal(ctx, "ancestorsQueryCondition", q) + qq.Join = append(q.Join, SqlBuilder{ + {"left join", table, fmt.Sprintf("%s.%s=%s.%s", tables[len(tables)-1], foreignKey, from[len(from)-1], localKey)}, + }...) + }) + } + fns = append(fns, func() error { + { + var w any = relation.Where + if w == nil { + w = SqlBuilder{} + } + ww, ok := w.(SqlBuilder) + if ok { + id = fmt.Sprintf("%v", v.Field(j).Interface()) + ww = append(ww, SqlBuilder{{ + foreignKey, "=", id, "int", + }}...) + relation.Where = ww + } } - ww, ok := w.(SqlBuilder) - if ok { - ww = append(ww, SqlBuilder{{ - tag.Get("foreignKey"), "=", id, "int", - }}...) - relation.Where = ww + var err error + vv := reflect.New(v.Field(i).Type().Elem()).Interface() + switch tag.Get("relation") { + case "hasOne": + err = parseRelation(false, db, ctx, vv, relation) + case "hasMany": + err = parseRelation(true, db, ctx, vv, relation) } - } - sq, args, er := BuildQuerySql(*relation) - if er != nil { - err = er - return - } - vv := reflect.New(v.Field(i).Type().Elem()).Interface() - switch tag.Get("relation") { - case "hasOne": - err = db.Get(ctx, vv, sq, args...) - case "hasMany": - err = db.Select(ctx, vv, sq, args...) - } - if err != nil { - return - } - v.Field(i).Set(reflect.ValueOf(vv)) + if err != nil { + return err + } + v.Field(i).Set(reflect.ValueOf(vv)) + return nil + }) } } - return + return fn, fns }