diff --git a/helper/func.go b/helper/func.go index 043d86a..b46c67b 100644 --- a/helper/func.go +++ b/helper/func.go @@ -1,6 +1,7 @@ package helper import ( + "context" "fmt" str "github.com/fthvgb1/wp-go/helper/strings" "net/url" @@ -94,3 +95,15 @@ func ToBool[T comparable](t T) bool { var vv T return vv != t } + +func GetValFromContext[V, K any](ctx context.Context, k K, defaults V) V { + v := ctx.Value(k) + if v == nil { + return defaults + } + vv, ok := v.(V) + if !ok { + return defaults + } + return vv +} diff --git a/helper/func_test.go b/helper/func_test.go index dc336b3..09ecc09 100644 --- a/helper/func_test.go +++ b/helper/func_test.go @@ -1,6 +1,7 @@ package helper import ( + "context" "fmt" "reflect" "testing" @@ -235,3 +236,34 @@ func TestIsZeros(t *testing.T) { } }) } + +func TestGetValFromContext(t *testing.T) { + type args[K any, V any] struct { + ctx context.Context + k K + defaults V + } + type testCase[K any, V any] struct { + name string + args args[K, V] + want V + } + tests := []testCase[string, int]{ + { + name: "t1", + args: args[string, int]{ + ctx: context.WithValue(context.Background(), "kk", 1), + k: "kk", + defaults: 0, + }, + want: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetValFromContext(tt.args.ctx, tt.args.k, tt.args.defaults); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetValFromContext() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/pkg/dao/common.go b/internal/pkg/dao/common.go index 7ddb116..3956838 100644 --- a/internal/pkg/dao/common.go +++ b/internal/pkg/dao/common.go @@ -2,6 +2,7 @@ package dao import ( "context" + "github.com/fthvgb1/wp-go/helper" "github.com/fthvgb1/wp-go/internal/pkg/constraints" "github.com/fthvgb1/wp-go/internal/pkg/models" "github.com/fthvgb1/wp-go/internal/wpconfig" @@ -32,11 +33,17 @@ func CategoriesAndTags(a ...any) (terms []models.TermsMy, err error) { in = []any{"post_tag"} } } + w := model.SqlBuilder{ + {"tt.taxonomy", "in", ""}, + } + if helper.GetValFromContext(ctx, "onlyTop", false) { + w = append(w, []string{"tt.parent", "=", "0", "int"}) + } + if !helper.GetValFromContext(ctx, "showCountZero", false) { + w = append(w, []string{"tt.count", ">", "0", "int"}) + } terms, err = model.Finds[models.TermsMy](ctx, model.Conditions( - model.Where(model.SqlBuilder{ - {"tt.count", ">", "0", "int"}, - {"tt.taxonomy", "in", ""}, - }), + model.Where(w), model.Fields("t.term_id"), model.Order(model.SqlBuilder{{"t.name", "asc"}}), model.Join(model.SqlBuilder{ diff --git a/model/query.go b/model/query.go index 0e468b1..f34b99b 100644 --- a/model/query.go +++ b/model/query.go @@ -2,6 +2,7 @@ package model import ( "context" + "github.com/fthvgb1/wp-go/helper" "github.com/fthvgb1/wp-go/helper/number" str "github.com/fthvgb1/wp-go/helper/strings" "golang.org/x/exp/constraints" @@ -64,20 +65,12 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page return } q.Offset = offset - m := ctx.Value("handle=>toMap") + m := helper.GetValFromContext[*[]map[string]string](ctx, "handle=>toMap", nil) if m == nil { r, err = finds[T](db, ctx, q) return } - mm, ok := m.(*[]map[string]string) - if ok { - mx, er := findToStringMap[T](db, ctx, q) - if er != nil { - err = er - return - } - *mm = mx - } + *m, err = findToStringMap[T](db, ctx, q) return } diff --git a/model/sqxquery.go b/model/sqxquery.go index 08cf450..fc57f3c 100644 --- a/model/sqxquery.go +++ b/model/sqxquery.go @@ -3,6 +3,7 @@ package model import ( "context" "fmt" + "github.com/fthvgb1/wp-go/helper" "github.com/fthvgb1/wp-go/helper/slice" str "github.com/fthvgb1/wp-go/helper/strings" "github.com/jmoiron/sqlx" @@ -35,31 +36,26 @@ func SetGet(db *SqlxQuery, fn func(context.Context, any, string, ...any) error) } func (r *SqlxQuery) Selects(ctx context.Context, dest any, sql string, params ...any) error { - v := ctx.Value("handle=>") - if v != nil { - vv, ok := v.(string) - if ok && vv != "" { - switch vv { - case "string": - return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...) - case "scanner": - fn := ctx.Value("fn") - return Scanner[any](r.sqlx, dest, sql, params...)(fn.(func(any))) - } + v := helper.GetValFromContext(ctx, "handle=>", "") + if v != "" { + switch v { + case "string": + return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...) + case "scanner": + fn := ctx.Value("fn") + return Scanner[any](r.sqlx, dest, sql, params...)(fn.(func(any))) } } + return r.sqlx.Select(dest, sql, params...) } func (r *SqlxQuery) Gets(ctx context.Context, dest any, sql string, params ...any) error { - v := ctx.Value("handle=>") - if v != nil { - vv, ok := v.(string) - if ok && vv != "" { - switch vv { - case "string": - return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...) - } + v := helper.GetValFromContext(ctx, "handle=>", "") + if v != "" { + switch v { + case "string": + return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...) } } return r.sqlx.Get(dest, sql, params...)