hasone hasmany 待完善

This commit is contained in:
xing 2023-05-17 22:22:31 +08:00
parent a6ee333232
commit 044e55a399
5 changed files with 184 additions and 60 deletions

View File

@ -11,6 +11,7 @@ type QueryCondition struct {
Limit int Limit int
Offset int Offset int
In [][]any In [][]any
Relation map[string]*QueryCondition
} }
func Conditions(fns ...Condition) QueryCondition { func Conditions(fns ...Condition) QueryCondition {
@ -23,6 +24,16 @@ func Conditions(fns ...Condition) QueryCondition {
} }
return r return r
} }
func WithConditions(fns ...Condition) *QueryCondition {
r := QueryCondition{}
for _, fn := range fns {
fn(&r)
}
if r.Fields == "" {
r.Fields = "*"
}
return &r
}
type Condition func(c *QueryCondition) type Condition func(c *QueryCondition)
@ -84,3 +95,12 @@ func In(in ...[]any) Condition {
c.In = append(c.In, in...) c.In = append(c.In, in...)
} }
} }
func With(tableTag string, q *QueryCondition) Condition {
return func(c *QueryCondition) {
if c.Relation == nil {
c.Relation = map[string]*QueryCondition{}
}
c.Relation[tableTag] = q
}
}

View File

@ -39,7 +39,10 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page
} }
if q.Group != "" { if q.Group != "" {
qx.Fields = q.Fields qx.Fields = q.Fields
sq, in, er := BuildQuerySql[T](qx) if qx.From == "" {
qx.From = Table[T]()
}
sq, in, er := BuildQuerySql(qx)
qx.In = [][]any{in} qx.In = [][]any{in}
if er != nil { if er != nil {
err = er err = er
@ -91,24 +94,18 @@ func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T,
return gets[T](globalBb, ctx, QueryCondition{ return gets[T](globalBb, ctx, QueryCondition{
Fields: "*", Fields: "*",
Where: SqlBuilder{ Where: SqlBuilder{
{PrimaryKey[T](), "=", number.ToString(id), "int"}, {PrimaryKey[T](), "=", number.IntToString(id), "int"},
}, },
}) })
} }
func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (r T, err error) { func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (T, error) {
s, args, err := BuildQuerySql[T](QueryCondition{ return gets[T](globalBb, ctx, Conditions(
Where: where, Where(where),
Fields: fields, Fields(fields),
Order: order, Order(order),
In: in, In(in...),
Limit: 1, ))
})
if err != nil {
return
}
err = globalBb.Get(ctx, &r, s, args...)
return
} }
func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (T, error) { func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (T, error) {
@ -122,10 +119,11 @@ func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in .
} }
func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) { func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) {
s, args, err := BuildQuerySql[T](QueryCondition{ s, args, err := BuildQuerySql(QueryCondition{
Where: where, Where: where,
Fields: fields, Fields: fields,
In: in, In: in,
From: Table[T](),
}) })
if err != nil { if err != nil {
return return
@ -155,8 +153,9 @@ func Find[T Model](ctx context.Context, where ParseWhere, fields, group string,
Having: having, Having: having,
Limit: limit, Limit: limit,
In: in, In: in,
From: Table[T](),
} }
s, args, err := BuildQuerySql[T](q) s, args, err := BuildQuerySql(q)
if err != nil { if err != nil {
return return
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/fthvgb1/wp-go/app/pkg/models"
"github.com/fthvgb1/wp-go/safety" "github.com/fthvgb1/wp-go/safety"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@ -36,6 +37,8 @@ type post struct {
PostType string `gorm:"column:post_type" db:"post_type" json:"post_type" form:"post_type"` PostType string `gorm:"column:post_type" db:"post_type" json:"post_type" form:"post_type"`
PostMimeType string `gorm:"column:post_mime_type" db:"post_mime_type" json:"post_mime_type" form:"post_mime_type"` PostMimeType string `gorm:"column:post_mime_type" db:"post_mime_type" json:"post_mime_type" form:"post_mime_type"`
CommentCount int64 `gorm:"column:comment_count" db:"comment_count" json:"comment_count" form:"comment_count"` CommentCount int64 `gorm:"column:comment_count" db:"comment_count" json:"comment_count" form:"comment_count"`
User *user `table:"wp_users user" foreignKey:"ID" local:"post_author" relation:"hasOne"`
PostMeta *[]models.PostMeta `table:"wp_postmeta meta" foreignKey:"post_id" local:"ID" relation:"hasMany"`
} }
type user struct { type user struct {
@ -362,6 +365,18 @@ func TestFirstOne(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := FirstOne[post](ctx, tt.args.where, tt.args.fields, tt.args.order, tt.args.in...) got, err := FirstOne[post](ctx, tt.args.where, tt.args.fields, tt.args.order, tt.args.in...)
gott, err := Gets[post](ctx, Conditions(
Where(SqlBuilder{{"post_status", "publish"}}),
Order([][]string{{"ID", "desc"}}),
With("user", WithConditions(
Fields("ID,user_login,user_pass"),
Where(SqlBuilder{
{"user.ID", ">", "0", "int"},
}),
)),
With("meta", nil),
))
_ = gott
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("FirstOne() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("FirstOne() error = %v, wantErr %v", err, tt.wantErr)
return return

View File

@ -26,11 +26,13 @@ func FindFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r [
} }
func finds[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) { func finds[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []T, err error) {
sq, args, err := BuildQuerySql[T](q) setTable[T](&q)
sq, args, err := BuildQuerySql(q)
if err != nil { if err != nil {
return return
} }
err = db.Select(ctx, &r, sq, args...) err = db.Select(ctx, &r, sq, args...)
return return
} }
@ -179,7 +181,8 @@ func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q Qu
} }
func getToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) { func getToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r map[string]string, err error) {
rawSql, in, err := BuildQuerySql[T](q) setTable[T](&q)
rawSql, in, err := BuildQuerySql(q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -193,7 +196,8 @@ func GetToStringMap[T Model](ctx context.Context, q QueryCondition) (r map[strin
} }
func findToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) { func findToStringMap[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r []map[string]string, err error) {
rawSql, in, err := BuildQuerySql[T](q) setTable[T](&q)
rawSql, in, err := BuildQuerySql(q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -217,7 +221,7 @@ func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondi
return return
} }
func BuildQuerySql[T Model](q QueryCondition) (r string, args []any, err error) { func BuildQuerySql(q QueryCondition) (r string, args []any, err error) {
w := "" w := ""
if q.Where != nil { if q.Where != nil {
w, args, err = q.Where.ParseWhere(&q.In) w, args, err = q.Where.ParseWhere(&q.In)
@ -251,10 +255,7 @@ func BuildQuerySql[T Model](q QueryCondition) (r string, args []any, err error)
} }
tp := "select %s from %s %s %s %s %s %s %s" tp := "select %s from %s %s %s %s %s %s %s"
l := "" l := ""
table := Table[T]() table := q.From
if q.From != "" {
table = q.From
}
if q.Limit > 0 { if q.Limit > 0 {
l = fmt.Sprintf(" limit %d", q.Limit) l = fmt.Sprintf(" limit %d", q.Limit)
} }
@ -266,7 +267,8 @@ func BuildQuerySql[T Model](q QueryCondition) (r string, args []any, err error)
} }
func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) (err error) { func findScanner[T Model](db dbQuery, ctx context.Context, fn func(T), q QueryCondition) (err error) {
s, args, err := BuildQuerySql[T](q) setTable[T](&q)
s, args, err := BuildQuerySql(q)
if err != nil { if err != nil {
return return
} }
@ -295,10 +297,17 @@ func GetsFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondition) (T,
} }
func gets[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r T, err error) { func gets[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r T, err error) {
s, args, err := BuildQuerySql[T](q) setTable[T](&q)
s, args, err := BuildQuerySql(q)
if err != nil { if err != nil {
return return
} }
err = db.Get(ctx, &r, s, args...) err = db.Get(ctx, &r, s, args...)
if err != nil {
return
}
if len(q.Relation) > 0 {
err = Relation[T](db, ctx, &r, &q)
}
return return
} }

81
model/relation.go Normal file
View File

@ -0,0 +1,81 @@
package model
import (
"context"
"fmt"
"reflect"
"strings"
)
func setTable[T Model](q *QueryCondition) {
if q.From == "" {
q.From = Table[T]()
}
}
func Relation[T Model](db dbQuery, ctx context.Context, r *T, q *QueryCondition) (err error) {
var rr T
t := reflect.TypeOf(rr)
v := reflect.ValueOf(r).Elem()
for tableTag, relation := range q.Relation {
if tableTag == "" {
continue
}
for i := 0; i < t.NumField(); i++ {
tag := t.Field(i).Tag
table, ok := tag.Lookup("table")
if !ok || table == "" {
continue
}
tables := strings.Split(table, " ")
if tables[len(tables)-1] != tableTag {
continue
}
if relation == nil {
relation = &QueryCondition{
Fields: "*",
}
}
relation.From = table
id := ""
j := 0
for ; j < t.NumField(); j++ {
vvv, ok := t.Field(j).Tag.Lookup("db")
if ok && vvv == tag.Get("local") {
id = fmt.Sprintf("%v", v.Field(j).Interface())
break
}
}
{
var w any = relation.Where
if w == nil {
w = SqlBuilder{}
}
ww, ok := w.(SqlBuilder)
if ok {
ww = append(ww, SqlBuilder{{
tag.Get("foreignKey"), "=", id, "int",
}}...)
relation.Where = ww
}
}
sq, args, er := BuildQuerySql(*relation)
if er != nil {
err = er
return
}
vv := reflect.New(v.Field(i).Type().Elem()).Interface()
switch tag.Get("relation") {
case "hasOne":
err = db.Get(ctx, vv, sq, args...)
case "hasMany":
err = db.Select(ctx, vv, sq, args...)
}
if err != nil {
return
}
v.Field(i).Set(reflect.ValueOf(vv))
}
}
return
}