diff --git a/helper/func.go b/helper/func.go index 71a81b5..5493b87 100644 --- a/helper/func.go +++ b/helper/func.go @@ -114,3 +114,9 @@ func GetContextVal[V, K any](ctx context.Context, k K, defaults V) V { } return vv } + +func IsImplements[T, A any](i A) (T, bool) { + var a any = i + t, ok := a.(T) + return t, ok +} diff --git a/helper/slice/slices.go b/helper/slice/slices.go index e0056bf..b86e502 100644 --- a/helper/slice/slices.go +++ b/helper/slice/slices.go @@ -71,3 +71,23 @@ func Unshift[T any](a *[]T, e ...T) { func Push[T any](a *[]T, e ...T) { *a = append(*a, e...) } + +func Decompress[T any](a [][]T) (r []T) { + for _, ts := range a { + for _, t := range ts { + r = append(r, t) + } + } + return +} +func DecompressBy[T, R any](a [][]T, fn func(T) (R, bool)) (r []R) { + for _, ts := range a { + for _, t := range ts { + v, ok := fn(t) + if ok { + r = append(r, v) + } + } + } + return +} diff --git a/model/model.go b/model/model.go index 0a426fe..f50d4af 100644 --- a/model/model.go +++ b/model/model.go @@ -22,6 +22,10 @@ type ParseWhere interface { ParseWhere(*[][]any) (string, []any, error) } +type AndWhere interface { + AndWhere(field, operator, val, fieldType string) ParseWhere +} + type dbQuery interface { Select(context.Context, any, string, ...any) error Get(context.Context, any, string, ...any) error diff --git a/model/parse.go b/model/parse.go index fedf647..a878294 100644 --- a/model/parse.go +++ b/model/parse.go @@ -186,3 +186,10 @@ func (w SqlBuilder) parseJoin() string { } return s.String() } + +func (w SqlBuilder) AndWhere(field, operator, val, fieldType string) ParseWhere { + ww := append(w, []string{ + field, operator, val, fieldType, + }) + return ww +} diff --git a/model/relation.go b/model/relation.go index 0414e78..c976266 100644 --- a/model/relation.go +++ b/model/relation.go @@ -24,7 +24,7 @@ const ( // Relationship join table // -// RelationType HasOne| HasMany +// # RelationType HasOne| HasMany // // eg: hasOne, post has a user. ForeignKey is user's id , Local is post's userId field // @@ -90,6 +90,18 @@ func parseAfterJoin(fromTable string, ids [][]any, qq *QueryCondition, ship Rela tables[len(tables)-1], ship.Middle.ForeignKey), "in", ""}, ) qq.Where = ww + } else { + aw, ok := helper.IsImplements[AndWhere](qq.Where) + if ok { + vv := aw.AndWhere(fmt.Sprintf("%s.%s", + tables[len(tables)-1], ship.Middle.ForeignKey), "in", strings.Join(slice.DecompressBy(ids, func(t any) (string, bool) { + return fmt.Sprintf("%v", t), true + }), ","), "int") + wa, ok := helper.IsImplements[ParseWhere](vv) + if ok { + qq.Where = wa + } + } } if qq.Fields == "" || qq.Fields == "*" { qq.Fields = str.Join(from[len(from)-1], ".", "*", ",", tables[len(tables)-1], ".", ship.Middle.ForeignKey) @@ -133,18 +145,26 @@ func Relation(isPlural bool, db dbQuery, ctx context.Context, r any, q *QueryCon if w == nil { qq.Where = SqlBuilder{} } - ww, ok := qq.Where.(SqlBuilder) in := [][]any{ids} - if ok { - if ship.Middle != nil { - isPlural = parseAfterJoin(qq.From, in, qq, ship) - } else { + if ship.Middle != nil { + isPlural = parseAfterJoin(qq.From, in, qq, ship) + } else { + ww, ok := qq.Where.(SqlBuilder) + if ok { ww = append(ww, SqlBuilder{{ ship.ForeignKey, "in", "", }}...) - qq.In = in qq.Where = ww + } else { + aw, ok := helper.IsImplements[AndWhere](qq.Where) + if ok { + ww := aw.AndWhere(ship.ForeignKey, "in", strings.Join(slice.Map(ids, func(t any) string { + return fmt.Sprintf("%v", t) + }), ","), "int") + qq.Where = ww + } } + qq.In = in } err = ParseRelation(isPlural || ship.RelationType == HasMany, db, ctx, helper.Or(isPlural, rrs, rr), qq) if err != nil {