From f28c41c84afd873ee7b578cf382c5c5f831c1114 Mon Sep 17 00:00:00 2001 From: xing Date: Mon, 6 Feb 2023 17:58:24 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81=20?= =?UTF-8?q?=E8=B0=83=E6=95=B4=20model=20query?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/cmd/main.go | 4 +- internal/middleware/staticFileCache.go | 2 +- internal/pkg/dao/posts.go | 29 ++++++------ internal/pkg/db/db.go | 64 ++++---------------------- model/query.go | 13 ++++-- model/query_test.go | 41 +---------------- model/querycondition.go | 64 ++++++++++++++++++++++---- model/sqxquery.go | 51 ++++++++++++++++++++ 8 files changed, 143 insertions(+), 125 deletions(-) create mode 100644 model/sqxquery.go diff --git a/internal/cmd/main.go b/internal/cmd/main.go index 4165dcd..a75b098 100644 --- a/internal/cmd/main.go +++ b/internal/cmd/main.go @@ -55,11 +55,11 @@ func initConf(c string) (err error) { return } - err = db.InitDb() + database, err := db.InitDb() if err != nil { return } - model.InitDB(db.NewSqlxDb(db.Db)) + model.InitDB(model.NewSqlxQuery(database)) err = wpconfig.InitOptions() if err != nil { return diff --git a/internal/middleware/staticFileCache.go b/internal/middleware/staticFileCache.go index 9ada8e2..993325b 100644 --- a/internal/middleware/staticFileCache.go +++ b/internal/middleware/staticFileCache.go @@ -8,7 +8,7 @@ import ( ) var path = map[string]struct{}{ - "includes": {}, + "wp-includes": {}, "wp-content": {}, "favicon.ico": {}, } diff --git a/internal/pkg/dao/posts.go b/internal/pkg/dao/posts.go index b92df99..e782230 100644 --- a/internal/pkg/dao/posts.go +++ b/internal/pkg/dao/posts.go @@ -13,20 +13,21 @@ import ( "time" ) -func GetPostsByIds(ids ...any) (m map[uint64]models.Posts, err error) { - ctx := ids[0].(context.Context) +func GetPostsByIds(a ...any) (m map[uint64]models.Posts, err error) { + ctx := a[0].(context.Context) m = make(map[uint64]models.Posts) - id := ids[1].([]uint64) - arg := slice.ToAnySlice(id) - rawPosts, err := model.Find[models.Posts](ctx, model.SqlBuilder{{ - "Id", "in", "", - }}, "a.*,ifnull(d.name,'') category_name,ifnull(taxonomy,'') `taxonomy`", "", nil, model.SqlBuilder{{ - "a", "left join", "wp_term_relationships b", "a.Id=b.object_id", - }, { - "left join", "wp_term_taxonomy c", "b.term_taxonomy_id=c.term_taxonomy_id", - }, { - "left join", "wp_terms d", "c.term_id=d.term_id", - }}, nil, 0, arg) + ids := a[1].([]uint64) + rawPosts, err := model.Finds[models.Posts](ctx, model.Conditions( + model.Where(model.SqlBuilder{{"Id", "in", ""}}), + model.Join(model.SqlBuilder{ + {"a", "left join", "wp_term_relationships b", "a.Id=b.object_id"}, + {"left join", "wp_term_taxonomy c", "b.term_taxonomy_id=c.term_taxonomy_id"}, + {"left join", "wp_terms d", "c.term_id=d.term_id"}, + }), + model.Fields("a.*,ifnull(d.name,'') category_name,ifnull(taxonomy,'') `taxonomy`"), + model.In(slice.ToAnySlice(ids)), + )) + if err != nil { return m, err } @@ -45,7 +46,7 @@ func GetPostsByIds(ids ...any) (m map[uint64]models.Posts, err error) { } //host, _ := wpconfig.Options.Load("siteurl") host := "" - meta, _ := GetPostMetaByPostIds(ctx, id) + meta, _ := GetPostMetaByPostIds(ctx, ids) for k, pp := range postsMap { if len(pp.Categories) > 0 { t := make([]string, 0, len(pp.Categories)) diff --git a/internal/pkg/db/db.go b/internal/pkg/db/db.go index 7e64406..3907203 100644 --- a/internal/pkg/db/db.go +++ b/internal/pkg/db/db.go @@ -1,78 +1,32 @@ package db import ( - "context" - "fmt" "github.com/fthvgb1/wp-go/internal/pkg/config" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" - "log" - "os" - "strconv" - "strings" ) -var Db *sqlx.DB +var db *sqlx.DB -type SqlxDb struct { - sqlx *sqlx.DB -} - -func NewSqlxDb(sqlx *sqlx.DB) *SqlxDb { - return &SqlxDb{sqlx: sqlx} -} - -func (r SqlxDb) Select(ctx context.Context, dest any, sql string, params ...any) error { - if os.Getenv("SHOW_SQL") == "true" { - go log.Println(formatSql(sql, params)) - } - return r.sqlx.Select(dest, sql, params...) -} - -func (r SqlxDb) Get(ctx context.Context, dest any, sql string, params ...any) error { - if os.Getenv("SHOW_SQL") == "true" { - go log.Println(formatSql(sql, params)) - } - return r.sqlx.Get(dest, sql, params...) -} - -func formatSql(sql string, params []any) string { - for _, param := range params { - switch param.(type) { - case string: - sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1) - case int64: - sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1) - case int: - sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1) - case uint64: - sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1) - case float64: - sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1) - } - } - return sql -} - -func InitDb() error { +func InitDb() (*sqlx.DB, error) { c := config.GetConfig() dsn := c.Mysql.Dsn.GetDsn() var err error - Db, err = sqlx.Open("mysql", dsn) + db, err = sqlx.Open("mysql", dsn) if err != nil { - return err + return nil, err } if c.Mysql.Pool.ConnMaxIdleTime != 0 { - Db.SetConnMaxIdleTime(c.Mysql.Pool.ConnMaxLifetime) + db.SetConnMaxIdleTime(c.Mysql.Pool.ConnMaxLifetime) } if c.Mysql.Pool.MaxIdleConn != 0 { - Db.SetMaxIdleConns(c.Mysql.Pool.MaxIdleConn) + db.SetMaxIdleConns(c.Mysql.Pool.MaxIdleConn) } if c.Mysql.Pool.MaxOpenConn != 0 { - Db.SetMaxOpenConns(c.Mysql.Pool.MaxOpenConn) + db.SetMaxOpenConns(c.Mysql.Pool.MaxOpenConn) } if c.Mysql.Pool.ConnMaxLifetime != 0 { - Db.SetConnMaxLifetime(c.Mysql.Pool.ConnMaxLifetime) + db.SetConnMaxLifetime(c.Mysql.Pool.ConnMaxLifetime) } - return err + return db, err } diff --git a/model/query.go b/model/query.go index 2f0278e..2373666 100644 --- a/model/query.go +++ b/model/query.go @@ -8,7 +8,7 @@ import ( "strings" ) -func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { +func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { var rr T var w string var args []any @@ -55,11 +55,11 @@ func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, gr if group == "" { tpx := "select count(*) n from %s %s %s limit 1" sq := fmt.Sprintf(tpx, rr.Table(), j, w) - err = globalBb.Get(ctx, &n, sq, args...) + err = db.Get(ctx, &n, sq, args...) } else { tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s" sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int())) - err = globalBb.Get(ctx, &n, sq, args...) + err = db.Get(ctx, &n, sq, args...) } if err != nil { @@ -78,13 +78,18 @@ func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, gr } tp := "select %s from %s %s %s %s %s %s limit %d,%d" sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) - err = globalBb.Select(ctx, &r, sq, args...) + err = db.Select(ctx, &r, sq, args...) if err != nil { return } return } +func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { + r, total, err = pagination[T](globalBb, ctx, where, fields, group, page, pageSize, order, join, having, in...) + return +} + func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) { var r T sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) diff --git a/model/query_test.go b/model/query_test.go index 8e335a1..1aafe69 100644 --- a/model/query_test.go +++ b/model/query_test.go @@ -3,15 +3,11 @@ package model import ( "context" "database/sql" - "fmt" "github.com/fthvgb1/wp-go/helper/number" "github.com/fthvgb1/wp-go/helper/slice" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" - "log" "reflect" - "strconv" - "strings" "testing" "time" ) @@ -102,40 +98,6 @@ func (p post) Table() string { return "wp_posts" } -type SqlxDb struct { - sqlx *sqlx.DB -} - -var Db *SqlxDb - -func (r SqlxDb) Select(_ context.Context, dest any, sql string, params ...any) error { - log.Println(formatSql(sql, params)) - return r.sqlx.Select(dest, sql, params...) -} - -func (r SqlxDb) Get(_ context.Context, dest any, sql string, params ...any) error { - log.Println(formatSql(sql, params)) - return r.sqlx.Get(dest, sql, params...) -} - -func formatSql(sql string, params []any) string { - for _, param := range params { - switch param.(type) { - case string: - sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1) - case int64: - sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1) - case int: - sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1) - case uint64: - sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1) - case float64: - sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1) - } - } - return sql -} - var ctx = context.Background() func init() { @@ -143,8 +105,7 @@ func init() { if err != nil { panic(err) } - Db = &SqlxDb{db} - InitDB(Db) + InitDB(NewSqlxQuery(db)) } func TestFind(t *testing.T) { type args struct { diff --git a/model/querycondition.go b/model/querycondition.go index 6b720a3..b5fbcbb 100644 --- a/model/querycondition.go +++ b/model/querycondition.go @@ -11,6 +11,19 @@ import ( // // Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { + r, err = finds[T](globalBb, ctx, q) + return +} + +// DBFind 同 Finds 使用指定 db 查询 +// +// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 +func DBFind[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { + r, err = finds[T](db, ctx, q) + return +} + +func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) { var rr T w := "" var args []any @@ -48,25 +61,22 @@ func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { l = fmt.Sprintf(" %s offset %d", l, q.offset) } sq := fmt.Sprintf(tp, q.fields, rr.Table(), j, w, groupBy, h, q.order.parseOrderBy(), l) - err = globalBb.Select(ctx, &r, sq, args...) + err = db.Select(ctx, &r, sq, args...) return } -// ChunkFind 分片查询并直接返回所有结果 -// -// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 -func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { +func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { i := 1 var rr []T var total int var offset int for { if 1 == i { - rr, total, err = SimplePagination[T](ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) + rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) } else { q.offset = offset q.limit = perLimit - rr, err = Finds[T](ctx, q) + rr, err = finds[T](db, ctx, q) } offset += perLimit if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { @@ -81,10 +91,39 @@ func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r return } +// ChunkFind 分片查询并直接返回所有结果 +// +// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 +func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { + r, err = chunkFind[T](globalBb, ctx, perLimit, q) + return +} + +// DBChunkFind 同 ChunkFind +// +// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 +func DBChunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) { + r, err = chunkFind[T](db, ctx, perLimit, q) + return +} + // Chunk 分片查询并函数过虑返回新类型的切片 // // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { + r, err = chunk(globalBb, ctx, perLimit, fn, q) + return +} + +// DBChunk 同 Chunk +// +// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 +func DBChunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { + r, err = chunk(db, ctx, perLimit, fn, q) + return +} + +func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { i := 1 var rr []T var count int @@ -92,11 +131,11 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R var offset int for { if 1 == i { - rr, total, err = SimplePagination[T](ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) + rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) } else { q.offset = offset q.limit = perLimit - rr, err = Finds[T](ctx, q) + rr, err = finds[T](db, ctx, q) } offset += perLimit if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { @@ -123,3 +162,10 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R func Pagination[T Model](ctx context.Context, q *QueryCondition) ([]T, int, error) { return SimplePagination[T](ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) } + +// DBPagination 同 Pagination 方便多个db使用 +// +// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数 +func DBPagination[T Model](db dbQuery, ctx context.Context, q *QueryCondition) ([]T, int, error) { + return pagination[T](db, ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) +} diff --git a/model/sqxquery.go b/model/sqxquery.go new file mode 100644 index 0000000..9b7ac29 --- /dev/null +++ b/model/sqxquery.go @@ -0,0 +1,51 @@ +package model + +import ( + "context" + "fmt" + "github.com/jmoiron/sqlx" + "log" + "os" + "strconv" + "strings" +) + +type SqlxQuery struct { + sqlx *sqlx.DB +} + +func NewSqlxQuery(sqlx *sqlx.DB) SqlxQuery { + return SqlxQuery{sqlx: sqlx} +} + +func (r SqlxQuery) Select(ctx context.Context, dest any, sql string, params ...any) error { + if os.Getenv("SHOW_SQL") == "true" { + go log.Println(formatSql(sql, params)) + } + return r.sqlx.Select(dest, sql, params...) +} + +func (r SqlxQuery) Get(ctx context.Context, dest any, sql string, params ...any) error { + if os.Getenv("SHOW_SQL") == "true" { + go log.Println(formatSql(sql, params)) + } + return r.sqlx.Get(dest, sql, params...) +} + +func formatSql(sql string, params []any) string { + for _, param := range params { + switch param.(type) { + case string: + sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1) + case int64: + sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1) + case int: + sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1) + case uint64: + sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1) + case float64: + sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1) + } + } + return sql +}