Compare commits

...

2 Commits

Author SHA1 Message Date
b6091c6b42 注释 2023-05-30 20:04:05 +08:00
f2d69196bc 加个接口 2023-05-30 20:02:52 +08:00
5 changed files with 63 additions and 6 deletions

View File

@ -114,3 +114,9 @@ func GetContextVal[V, K any](ctx context.Context, k K, defaults V) V {
} }
return vv return vv
} }
func IsImplements[T, A any](i A) (T, bool) {
var a any = i
t, ok := a.(T)
return t, ok
}

View File

@ -71,3 +71,23 @@ func Unshift[T any](a *[]T, e ...T) {
func Push[T any](a *[]T, e ...T) { func Push[T any](a *[]T, e ...T) {
*a = append(*a, e...) *a = append(*a, e...)
} }
func Decompress[T any](a [][]T) (r []T) {
for _, ts := range a {
for _, t := range ts {
r = append(r, t)
}
}
return
}
func DecompressBy[T, R any](a [][]T, fn func(T) (R, bool)) (r []R) {
for _, ts := range a {
for _, t := range ts {
v, ok := fn(t)
if ok {
r = append(r, v)
}
}
}
return
}

View File

@ -22,6 +22,10 @@ type ParseWhere interface {
ParseWhere(*[][]any) (string, []any, error) ParseWhere(*[][]any) (string, []any, error)
} }
type AndWhere interface {
AndWhere(field, operator, val, fieldType string) ParseWhere
}
type dbQuery interface { type dbQuery interface {
Select(context.Context, any, string, ...any) error Select(context.Context, any, string, ...any) error
Get(context.Context, any, string, ...any) error Get(context.Context, any, string, ...any) error

View File

@ -186,3 +186,10 @@ func (w SqlBuilder) parseJoin() string {
} }
return s.String() return s.String()
} }
func (w SqlBuilder) AndWhere(field, operator, val, fieldType string) ParseWhere {
ww := append(w, []string{
field, operator, val, fieldType,
})
return ww
}

View File

@ -90,6 +90,18 @@ func parseAfterJoin(fromTable string, ids [][]any, qq *QueryCondition, ship Rela
tables[len(tables)-1], ship.Middle.ForeignKey), "in", ""}, tables[len(tables)-1], ship.Middle.ForeignKey), "in", ""},
) )
qq.Where = ww qq.Where = ww
} else {
aw, ok := helper.IsImplements[AndWhere](qq.Where)
if ok {
vv := aw.AndWhere(fmt.Sprintf("%s.%s",
tables[len(tables)-1], ship.Middle.ForeignKey), "in", strings.Join(slice.DecompressBy(ids, func(t any) (string, bool) {
return fmt.Sprintf("%v", t), true
}), ","), "int")
wa, ok := helper.IsImplements[ParseWhere](vv)
if ok {
qq.Where = wa
}
}
} }
if qq.Fields == "" || qq.Fields == "*" { if qq.Fields == "" || qq.Fields == "*" {
qq.Fields = str.Join(from[len(from)-1], ".", "*", ",", tables[len(tables)-1], ".", ship.Middle.ForeignKey) qq.Fields = str.Join(from[len(from)-1], ".", "*", ",", tables[len(tables)-1], ".", ship.Middle.ForeignKey)
@ -133,18 +145,26 @@ func Relation(isPlural bool, db dbQuery, ctx context.Context, r any, q *QueryCon
if w == nil { if w == nil {
qq.Where = SqlBuilder{} qq.Where = SqlBuilder{}
} }
ww, ok := qq.Where.(SqlBuilder)
in := [][]any{ids} in := [][]any{ids}
if ok { if ship.Middle != nil {
if ship.Middle != nil { isPlural = parseAfterJoin(qq.From, in, qq, ship)
isPlural = parseAfterJoin(qq.From, in, qq, ship) } else {
} else { ww, ok := qq.Where.(SqlBuilder)
if ok {
ww = append(ww, SqlBuilder{{ ww = append(ww, SqlBuilder{{
ship.ForeignKey, "in", "", ship.ForeignKey, "in", "",
}}...) }}...)
qq.In = in
qq.Where = ww qq.Where = ww
} else {
aw, ok := helper.IsImplements[AndWhere](qq.Where)
if ok {
ww := aw.AndWhere(ship.ForeignKey, "in", strings.Join(slice.Map(ids, func(t any) string {
return fmt.Sprintf("%v", t)
}), ","), "int")
qq.Where = ww
}
} }
qq.In = in
} }
err = ParseRelation(isPlural || ship.RelationType == HasMany, db, ctx, helper.Or(isPlural, rrs, rr), qq) err = ParseRelation(isPlural || ship.RelationType == HasMany, db, ctx, helper.Or(isPlural, rrs, rr), qq)
if err != nil { if err != nil {