Compare commits

...

2 Commits

Author SHA1 Message Date
341b7197b8 优化代码 2023-03-19 22:48:23 +08:00
26bdcb44ac 优化 2023-03-19 22:14:42 +08:00
7 changed files with 92 additions and 51 deletions

View File

@ -1,6 +1,7 @@
package helper package helper
import ( import (
"context"
"fmt" "fmt"
str "github.com/fthvgb1/wp-go/helper/strings" str "github.com/fthvgb1/wp-go/helper/strings"
"net/url" "net/url"
@ -94,3 +95,15 @@ func ToBool[T comparable](t T) bool {
var vv T var vv T
return vv != t return vv != t
} }
func GetValFromContext[V, K any](ctx context.Context, k K, defaults V) V {
v := ctx.Value(k)
if v == nil {
return defaults
}
vv, ok := v.(V)
if !ok {
return defaults
}
return vv
}

View File

@ -1,6 +1,7 @@
package helper package helper
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
@ -235,3 +236,34 @@ func TestIsZeros(t *testing.T) {
} }
}) })
} }
func TestGetValFromContext(t *testing.T) {
type args[K any, V any] struct {
ctx context.Context
k K
defaults V
}
type testCase[K any, V any] struct {
name string
args args[K, V]
want V
}
tests := []testCase[string, int]{
{
name: "t1",
args: args[string, int]{
ctx: context.WithValue(context.Background(), "kk", 1),
k: "kk",
defaults: 0,
},
want: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetValFromContext(tt.args.ctx, tt.args.k, tt.args.defaults); !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetValFromContext() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -193,19 +193,15 @@ func RenderedHtml(t *template.Template, data map[string]any) (r string, err erro
return return
} }
func BuildOptions[T any, K comparable](a []T, selected K, fn func(T) (K, any), attr ...string) string { func BuildOptions[T any, K comparable](a []T, selected K, fn func(T) (K, any, string)) string {
s := strings2.NewBuilder() s := strings2.NewBuilder()
att := ""
if len(attr) > 0 {
att = strings.Join(attr, " ")
}
for _, t := range a { for _, t := range a {
k, v := fn(t) k, v, attr := fn(t)
ss := "" ss := ""
if k == selected { if k == selected {
ss = "selected" ss = "selected"
} }
s.Sprintf(`<option %s %s value="%v">%v</option>`, ss, att, v, k) s.Sprintf(`<option %s %s value="%v">%v</option>`, ss, attr, v, k)
} }
return s.String() return s.String()
} }

View File

@ -2,6 +2,7 @@ package dao
import ( import (
"context" "context"
"github.com/fthvgb1/wp-go/helper"
"github.com/fthvgb1/wp-go/internal/pkg/constraints" "github.com/fthvgb1/wp-go/internal/pkg/constraints"
"github.com/fthvgb1/wp-go/internal/pkg/models" "github.com/fthvgb1/wp-go/internal/pkg/models"
"github.com/fthvgb1/wp-go/internal/wpconfig" "github.com/fthvgb1/wp-go/internal/wpconfig"
@ -32,11 +33,17 @@ func CategoriesAndTags(a ...any) (terms []models.TermsMy, err error) {
in = []any{"post_tag"} in = []any{"post_tag"}
} }
} }
w := model.SqlBuilder{
{"tt.taxonomy", "in", ""},
}
if helper.GetValFromContext(ctx, "onlyTop", false) {
w = append(w, []string{"tt.parent", "=", "0", "int"})
}
if !helper.GetValFromContext(ctx, "showCountZero", false) {
w = append(w, []string{"tt.count", ">", "0", "int"})
}
terms, err = model.Finds[models.TermsMy](ctx, model.Conditions( terms, err = model.Finds[models.TermsMy](ctx, model.Conditions(
model.Where(model.SqlBuilder{ model.Where(w),
{"tt.count", ">", "0", "int"},
{"tt.taxonomy", "in", ""},
}),
model.Fields("t.term_id"), model.Fields("t.term_id"),
model.Order(model.SqlBuilder{{"t.name", "asc"}}), model.Order(model.SqlBuilder{{"t.name", "asc"}}),
model.Join(model.SqlBuilder{ model.Join(model.SqlBuilder{

View File

@ -84,16 +84,27 @@ func categoryUL(h *wp.Handle, args map[string]string, conf map[any]any, categori
s := str.NewBuilder() s := str.NewBuilder()
s.WriteString("<ul>\n") s.WriteString("<ul>\n")
isCount := conf["count"].(int64) isCount := conf["count"].(int64)
currentCate := models.TermsMy{}
if h.Scene() == constraints.Category {
cat := h.C.Param("category")
_, currentCate = slice.SearchFirst(categories, func(my models.TermsMy) bool {
return cat == my.Name
})
}
if conf["hierarchical"].(int64) == 0 { if conf["hierarchical"].(int64) == 0 {
for _, category := range categories { for _, category := range categories {
count := "" count := ""
if isCount != 0 { if isCount != 0 {
count = fmt.Sprintf("(%d)", category.Count) count = fmt.Sprintf("(%d)", category.Count)
} }
s.Sprintf(` <li class="cat-item cat-item-%d"> current := ""
if category.TermTaxonomyId == currentCate.TermTaxonomyId {
current = "current-cat"
}
s.Sprintf(` <li class="cat-item cat-item-%d %s">
<a href="/p/category/%s">%s %s</a> <a href="/p/category/%s">%s %s</a>
</li> </li>
`, category.Terms.TermId, category.Name, category.Name, count) `, category.Terms.TermId, current, category.Name, category.Name, count)
} }
} else { } else {
@ -101,16 +112,9 @@ func categoryUL(h *wp.Handle, args map[string]string, conf map[any]any, categori
return cate.TermTaxonomyId, cate.Parent return cate.TermTaxonomyId, cate.Parent
}) })
cate := &tree.Node[models.TermsMy, uint64]{Data: models.TermsMy{}} cate := &tree.Node[models.TermsMy, uint64]{Data: models.TermsMy{}}
if h.Scene() == constraints.Category { if currentCate.TermTaxonomyId > 0 {
cat := h.C.Param("category") cate = m[currentCate.TermTaxonomyId]
i, ca := slice.SearchFirst(categories, func(my models.TermsMy) bool {
return cat == my.Name
})
if i > 0 {
cate = m[ca.TermTaxonomyId]
}
} }
r := m[0] r := m[0]
categoryLi(r, cate, tree.Ancestor(m, 0, cate), isCount, s) categoryLi(r, cate, tree.Ancestor(m, 0, cate), isCount, s)
} }

View File

@ -2,6 +2,7 @@ package model
import ( import (
"context" "context"
"github.com/fthvgb1/wp-go/helper"
"github.com/fthvgb1/wp-go/helper/number" "github.com/fthvgb1/wp-go/helper/number"
str "github.com/fthvgb1/wp-go/helper/strings" str "github.com/fthvgb1/wp-go/helper/strings"
"golang.org/x/exp/constraints" "golang.org/x/exp/constraints"
@ -64,20 +65,12 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page
return return
} }
q.Offset = offset q.Offset = offset
m := ctx.Value("handle=>toMap") m := helper.GetValFromContext[*[]map[string]string](ctx, "handle=>toMap", nil)
if m == nil { if m == nil {
r, err = finds[T](db, ctx, q) r, err = finds[T](db, ctx, q)
return return
} }
mm, ok := m.(*[]map[string]string) *m, err = findToStringMap[T](db, ctx, q)
if ok {
mx, er := findToStringMap[T](db, ctx, q)
if er != nil {
err = er
return
}
*mm = mx
}
return return
} }

View File

@ -3,6 +3,7 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/fthvgb1/wp-go/helper"
"github.com/fthvgb1/wp-go/helper/slice" "github.com/fthvgb1/wp-go/helper/slice"
str "github.com/fthvgb1/wp-go/helper/strings" str "github.com/fthvgb1/wp-go/helper/strings"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@ -35,31 +36,26 @@ func SetGet(db *SqlxQuery, fn func(context.Context, any, string, ...any) error)
} }
func (r *SqlxQuery) Selects(ctx context.Context, dest any, sql string, params ...any) error { func (r *SqlxQuery) Selects(ctx context.Context, dest any, sql string, params ...any) error {
v := ctx.Value("handle=>") v := helper.GetValFromContext(ctx, "handle=>", "")
if v != nil { if v != "" {
vv, ok := v.(string) switch v {
if ok && vv != "" { case "string":
switch vv { return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...)
case "string": case "scanner":
return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...) fn := ctx.Value("fn")
case "scanner": return Scanner[any](r.sqlx, dest, sql, params...)(fn.(func(any)))
fn := ctx.Value("fn")
return Scanner[any](r.sqlx, dest, sql, params...)(fn.(func(any)))
}
} }
} }
return r.sqlx.Select(dest, sql, params...) return r.sqlx.Select(dest, sql, params...)
} }
func (r *SqlxQuery) Gets(ctx context.Context, dest any, sql string, params ...any) error { func (r *SqlxQuery) Gets(ctx context.Context, dest any, sql string, params ...any) error {
v := ctx.Value("handle=>") v := helper.GetValFromContext(ctx, "handle=>", "")
if v != nil { if v != "" {
vv, ok := v.(string) switch v {
if ok && vv != "" { case "string":
switch vv { return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...)
case "string":
return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...)
}
} }
} }
return r.sqlx.Get(dest, sql, params...) return r.sqlx.Get(dest, sql, params...)