优化代码 调整 model query

This commit is contained in:
xing 2023-02-06 17:58:24 +08:00
parent dc203f87b8
commit f28c41c84a
8 changed files with 143 additions and 125 deletions

View File

@ -55,11 +55,11 @@ func initConf(c string) (err error) {
return return
} }
err = db.InitDb() database, err := db.InitDb()
if err != nil { if err != nil {
return return
} }
model.InitDB(db.NewSqlxDb(db.Db)) model.InitDB(model.NewSqlxQuery(database))
err = wpconfig.InitOptions() err = wpconfig.InitOptions()
if err != nil { if err != nil {
return return

View File

@ -8,7 +8,7 @@ import (
) )
var path = map[string]struct{}{ var path = map[string]struct{}{
"includes": {}, "wp-includes": {},
"wp-content": {}, "wp-content": {},
"favicon.ico": {}, "favicon.ico": {},
} }

View File

@ -13,20 +13,21 @@ import (
"time" "time"
) )
func GetPostsByIds(ids ...any) (m map[uint64]models.Posts, err error) { func GetPostsByIds(a ...any) (m map[uint64]models.Posts, err error) {
ctx := ids[0].(context.Context) ctx := a[0].(context.Context)
m = make(map[uint64]models.Posts) m = make(map[uint64]models.Posts)
id := ids[1].([]uint64) ids := a[1].([]uint64)
arg := slice.ToAnySlice(id) rawPosts, err := model.Finds[models.Posts](ctx, model.Conditions(
rawPosts, err := model.Find[models.Posts](ctx, model.SqlBuilder{{ model.Where(model.SqlBuilder{{"Id", "in", ""}}),
"Id", "in", "", model.Join(model.SqlBuilder{
}}, "a.*,ifnull(d.name,'') category_name,ifnull(taxonomy,'') `taxonomy`", "", nil, model.SqlBuilder{{ {"a", "left join", "wp_term_relationships b", "a.Id=b.object_id"},
"a", "left join", "wp_term_relationships b", "a.Id=b.object_id", {"left join", "wp_term_taxonomy c", "b.term_taxonomy_id=c.term_taxonomy_id"},
}, { {"left join", "wp_terms d", "c.term_id=d.term_id"},
"left join", "wp_term_taxonomy c", "b.term_taxonomy_id=c.term_taxonomy_id", }),
}, { model.Fields("a.*,ifnull(d.name,'') category_name,ifnull(taxonomy,'') `taxonomy`"),
"left join", "wp_terms d", "c.term_id=d.term_id", model.In(slice.ToAnySlice(ids)),
}}, nil, 0, arg) ))
if err != nil { if err != nil {
return m, err return m, err
} }
@ -45,7 +46,7 @@ func GetPostsByIds(ids ...any) (m map[uint64]models.Posts, err error) {
} }
//host, _ := wpconfig.Options.Load("siteurl") //host, _ := wpconfig.Options.Load("siteurl")
host := "" host := ""
meta, _ := GetPostMetaByPostIds(ctx, id) meta, _ := GetPostMetaByPostIds(ctx, ids)
for k, pp := range postsMap { for k, pp := range postsMap {
if len(pp.Categories) > 0 { if len(pp.Categories) > 0 {
t := make([]string, 0, len(pp.Categories)) t := make([]string, 0, len(pp.Categories))

View File

@ -1,78 +1,32 @@
package db package db
import ( import (
"context"
"fmt"
"github.com/fthvgb1/wp-go/internal/pkg/config" "github.com/fthvgb1/wp-go/internal/pkg/config"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"log"
"os"
"strconv"
"strings"
) )
var Db *sqlx.DB var db *sqlx.DB
type SqlxDb struct { func InitDb() (*sqlx.DB, error) {
sqlx *sqlx.DB
}
func NewSqlxDb(sqlx *sqlx.DB) *SqlxDb {
return &SqlxDb{sqlx: sqlx}
}
func (r SqlxDb) Select(ctx context.Context, dest any, sql string, params ...any) error {
if os.Getenv("SHOW_SQL") == "true" {
go log.Println(formatSql(sql, params))
}
return r.sqlx.Select(dest, sql, params...)
}
func (r SqlxDb) Get(ctx context.Context, dest any, sql string, params ...any) error {
if os.Getenv("SHOW_SQL") == "true" {
go log.Println(formatSql(sql, params))
}
return r.sqlx.Get(dest, sql, params...)
}
func formatSql(sql string, params []any) string {
for _, param := range params {
switch param.(type) {
case string:
sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1)
case int64:
sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1)
case int:
sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1)
case uint64:
sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1)
case float64:
sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1)
}
}
return sql
}
func InitDb() error {
c := config.GetConfig() c := config.GetConfig()
dsn := c.Mysql.Dsn.GetDsn() dsn := c.Mysql.Dsn.GetDsn()
var err error var err error
Db, err = sqlx.Open("mysql", dsn) db, err = sqlx.Open("mysql", dsn)
if err != nil { if err != nil {
return err return nil, err
} }
if c.Mysql.Pool.ConnMaxIdleTime != 0 { if c.Mysql.Pool.ConnMaxIdleTime != 0 {
Db.SetConnMaxIdleTime(c.Mysql.Pool.ConnMaxLifetime) db.SetConnMaxIdleTime(c.Mysql.Pool.ConnMaxLifetime)
} }
if c.Mysql.Pool.MaxIdleConn != 0 { if c.Mysql.Pool.MaxIdleConn != 0 {
Db.SetMaxIdleConns(c.Mysql.Pool.MaxIdleConn) db.SetMaxIdleConns(c.Mysql.Pool.MaxIdleConn)
} }
if c.Mysql.Pool.MaxOpenConn != 0 { if c.Mysql.Pool.MaxOpenConn != 0 {
Db.SetMaxOpenConns(c.Mysql.Pool.MaxOpenConn) db.SetMaxOpenConns(c.Mysql.Pool.MaxOpenConn)
} }
if c.Mysql.Pool.ConnMaxLifetime != 0 { if c.Mysql.Pool.ConnMaxLifetime != 0 {
Db.SetConnMaxLifetime(c.Mysql.Pool.ConnMaxLifetime) db.SetConnMaxLifetime(c.Mysql.Pool.ConnMaxLifetime)
} }
return err return db, err
} }

View File

@ -8,7 +8,7 @@ import (
"strings" "strings"
) )
func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { func pagination[T Model](db dbQuery, ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) {
var rr T var rr T
var w string var w string
var args []any var args []any
@ -55,11 +55,11 @@ func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, gr
if group == "" { if group == "" {
tpx := "select count(*) n from %s %s %s limit 1" tpx := "select count(*) n from %s %s %s limit 1"
sq := fmt.Sprintf(tpx, rr.Table(), j, w) sq := fmt.Sprintf(tpx, rr.Table(), j, w)
err = globalBb.Get(ctx, &n, sq, args...) err = db.Get(ctx, &n, sq, args...)
} else { } else {
tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s" tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s"
sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int())) sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int()))
err = globalBb.Get(ctx, &n, sq, args...) err = db.Get(ctx, &n, sq, args...)
} }
if err != nil { if err != nil {
@ -78,13 +78,18 @@ func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, gr
} }
tp := "select %s from %s %s %s %s %s %s limit %d,%d" tp := "select %s from %s %s %s %s %s %s limit %d,%d"
sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) sq := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize)
err = globalBb.Select(ctx, &r, sq, args...) err = db.Select(ctx, &r, sq, args...)
if err != nil { if err != nil {
return return
} }
return return
} }
func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) {
r, total, err = pagination[T](globalBb, ctx, where, fields, group, page, pageSize, order, join, having, in...)
return
}
func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) { func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) {
var r T var r T
sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey())

View File

@ -3,15 +3,11 @@ package model
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/fthvgb1/wp-go/helper/number" "github.com/fthvgb1/wp-go/helper/number"
"github.com/fthvgb1/wp-go/helper/slice" "github.com/fthvgb1/wp-go/helper/slice"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"log"
"reflect" "reflect"
"strconv"
"strings"
"testing" "testing"
"time" "time"
) )
@ -102,40 +98,6 @@ func (p post) Table() string {
return "wp_posts" return "wp_posts"
} }
type SqlxDb struct {
sqlx *sqlx.DB
}
var Db *SqlxDb
func (r SqlxDb) Select(_ context.Context, dest any, sql string, params ...any) error {
log.Println(formatSql(sql, params))
return r.sqlx.Select(dest, sql, params...)
}
func (r SqlxDb) Get(_ context.Context, dest any, sql string, params ...any) error {
log.Println(formatSql(sql, params))
return r.sqlx.Get(dest, sql, params...)
}
func formatSql(sql string, params []any) string {
for _, param := range params {
switch param.(type) {
case string:
sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1)
case int64:
sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1)
case int:
sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1)
case uint64:
sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1)
case float64:
sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1)
}
}
return sql
}
var ctx = context.Background() var ctx = context.Background()
func init() { func init() {
@ -143,8 +105,7 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
Db = &SqlxDb{db} InitDB(NewSqlxQuery(db))
InitDB(Db)
} }
func TestFind(t *testing.T) { func TestFind(t *testing.T) {
type args struct { type args struct {

View File

@ -11,6 +11,19 @@ import (
// //
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数 // Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) { func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) {
r, err = finds[T](globalBb, ctx, q)
return
}
// DBFind 同 Finds 使用指定 db 查询
//
// Conditions 中可用 Where Fields Group Having Join Order Offset Limit In 函数
func DBFind[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) {
r, err = finds[T](db, ctx, q)
return
}
func finds[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []T, err error) {
var rr T var rr T
w := "" w := ""
var args []any var args []any
@ -48,25 +61,22 @@ func Finds[T Model](ctx context.Context, q *QueryCondition) (r []T, err error) {
l = fmt.Sprintf(" %s offset %d", l, q.offset) l = fmt.Sprintf(" %s offset %d", l, q.offset)
} }
sq := fmt.Sprintf(tp, q.fields, rr.Table(), j, w, groupBy, h, q.order.parseOrderBy(), l) sq := fmt.Sprintf(tp, q.fields, rr.Table(), j, w, groupBy, h, q.order.parseOrderBy(), l)
err = globalBb.Select(ctx, &r, sq, args...) err = db.Select(ctx, &r, sq, args...)
return return
} }
// ChunkFind 分片查询并直接返回所有结果 func chunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
i := 1 i := 1
var rr []T var rr []T
var total int var total int
var offset int var offset int
for { for {
if 1 == i { if 1 == i {
rr, total, err = SimplePagination[T](ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...)
} else { } else {
q.offset = offset q.offset = offset
q.limit = perLimit q.limit = perLimit
rr, err = Finds[T](ctx, q) rr, err = finds[T](db, ctx, q)
} }
offset += perLimit offset += perLimit
if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 {
@ -81,10 +91,39 @@ func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r
return return
} }
// ChunkFind 分片查询并直接返回所有结果
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func ChunkFind[T Model](ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
r, err = chunkFind[T](globalBb, ctx, perLimit, q)
return
}
// DBChunkFind 同 ChunkFind
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func DBChunkFind[T Model](db dbQuery, ctx context.Context, perLimit int, q *QueryCondition) (r []T, err error) {
r, err = chunkFind[T](db, ctx, perLimit, q)
return
}
// Chunk 分片查询并函数过虑返回新类型的切片 // Chunk 分片查询并函数过虑返回新类型的切片
// //
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数 // Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) { func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
r, err = chunk(globalBb, ctx, perLimit, fn, q)
return
}
// DBChunk 同 Chunk
//
// Conditions 中可用 Where Fields Group Having Join Order Limit In 函数
func DBChunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
r, err = chunk(db, ctx, perLimit, fn, q)
return
}
func chunk[T Model, R any](db dbQuery, ctx context.Context, perLimit int, fn func(rows T) (R, bool), q *QueryCondition) (r []R, err error) {
i := 1 i := 1
var rr []T var rr []T
var count int var count int
@ -92,11 +131,11 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R
var offset int var offset int
for { for {
if 1 == i { if 1 == i {
rr, total, err = SimplePagination[T](ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...) rr, total, err = pagination[T](db, ctx, q.where, q.fields, q.group, i, perLimit, q.order, q.join, q.having, q.in...)
} else { } else {
q.offset = offset q.offset = offset
q.limit = perLimit q.limit = perLimit
rr, err = Finds[T](ctx, q) rr, err = finds[T](db, ctx, q)
} }
offset += perLimit offset += perLimit
if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 { if (err != nil && err != sql.ErrNoRows) || len(rr) < 1 {
@ -123,3 +162,10 @@ func Chunk[T Model, R any](ctx context.Context, perLimit int, fn func(rows T) (R
func Pagination[T Model](ctx context.Context, q *QueryCondition) ([]T, int, error) { func Pagination[T Model](ctx context.Context, q *QueryCondition) ([]T, int, error) {
return SimplePagination[T](ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...) return SimplePagination[T](ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...)
} }
// DBPagination 同 Pagination 方便多个db使用
//
// Condition 中可使用 Where Fields Group Having Join Order Page Limit In 函数
func DBPagination[T Model](db dbQuery, ctx context.Context, q *QueryCondition) ([]T, int, error) {
return pagination[T](db, ctx, q.where, q.fields, q.group, q.page, q.limit, q.order, q.join, q.having, q.in...)
}

51
model/sqxquery.go Normal file
View File

@ -0,0 +1,51 @@
package model
import (
"context"
"fmt"
"github.com/jmoiron/sqlx"
"log"
"os"
"strconv"
"strings"
)
type SqlxQuery struct {
sqlx *sqlx.DB
}
func NewSqlxQuery(sqlx *sqlx.DB) SqlxQuery {
return SqlxQuery{sqlx: sqlx}
}
func (r SqlxQuery) Select(ctx context.Context, dest any, sql string, params ...any) error {
if os.Getenv("SHOW_SQL") == "true" {
go log.Println(formatSql(sql, params))
}
return r.sqlx.Select(dest, sql, params...)
}
func (r SqlxQuery) Get(ctx context.Context, dest any, sql string, params ...any) error {
if os.Getenv("SHOW_SQL") == "true" {
go log.Println(formatSql(sql, params))
}
return r.sqlx.Get(dest, sql, params...)
}
func formatSql(sql string, params []any) string {
for _, param := range params {
switch param.(type) {
case string:
sql = strings.Replace(sql, "?", fmt.Sprintf("'%s'", param.(string)), 1)
case int64:
sql = strings.Replace(sql, "?", strconv.FormatInt(param.(int64), 10), 1)
case int:
sql = strings.Replace(sql, "?", strconv.Itoa(param.(int)), 1)
case uint64:
sql = strings.Replace(sql, "?", strconv.FormatUint(param.(uint64), 10), 1)
case float64:
sql = strings.Replace(sql, "?", fmt.Sprintf("%f", param.(float64)), 1)
}
}
return sql
}