diff --git a/actions/comment.go b/actions/comment.go index dc955e6..9b228ce 100644 --- a/actions/comment.go +++ b/actions/comment.go @@ -55,7 +55,7 @@ func PostComment(c *gin.Context) { for _, cookie := range res.Request.Response.Cookies() { c.SetCookie(cookie.Name, cookie.Value, cookie.MaxAge, cookie.Path, cookie.Domain, cookie.Secure, cookie.HttpOnly) } - c.Redirect(http.StatusFound, res.Request.Response.Header.Get("Location")) + //c.Redirect(http.StatusFound, res.Request.Response.Header.Get("Location")) cc := c.Copy() go func() { id, err := strconv.ParseUint(i, 10, 64) @@ -70,8 +70,13 @@ func PostComment(c *gin.Context) { } su := fmt.Sprintf("%s: %s[%s]发表了评论对文档[%v]的评论", wp.Option["siteurl"], author, m, post.PostTitle) err = mail.SendMail([]string{config.Conf.Mail.User}, su, comment) - logs.ErrPrintln(err, "发送邮件") + logs.ErrPrintln(err, "发送邮件", config.Conf.Mail.User, su, comment) }() + s, err := io.ReadAll(res.Body) + if err != nil { + return + } + c.String(http.StatusOK, string(s)) return } s, err := io.ReadAll(res.Body) diff --git a/actions/common/comments.go b/actions/common/comments.go index 7c3dfbd..76ddd06 100644 --- a/actions/common/comments.go +++ b/actions/common/comments.go @@ -11,15 +11,16 @@ import ( ) func RecentComments(ctx context.Context, n int) (r []wp.Comments) { - r, err := recentCommentsCaches.GetCache(ctx, time.Second) + r, err := recentCommentsCaches.GetCache(ctx, time.Second, ctx) if len(r) > n { r = r[0:n] } logs.ErrPrintln(err, "get recent comment") return } -func recentComments(...any) (r []wp.Comments, err error) { - return models.Find[wp.Comments](models.SqlBuilder{ +func recentComments(a ...any) (r []wp.Comments, err error) { + ctx := a[0].(context.Context) + return models.Find[wp.Comments](ctx, models.SqlBuilder{ {"comment_approved", "1"}, {"post_status", "publish"}, }, "comment_ID,comment_author,comment_post_ID,post_title", "", models.SqlBuilder{{"comment_date_gmt", "desc"}}, models.SqlBuilder{ @@ -28,7 +29,7 @@ func recentComments(...any) (r []wp.Comments, err error) { } func PostComments(ctx context.Context, Id uint64) ([]wp.Comments, error) { - ids, err := postCommentCaches.GetCache(ctx, Id, time.Second, Id) + ids, err := postCommentCaches.GetCache(ctx, Id, time.Second, ctx, Id) if err != nil { return nil, err } @@ -36,8 +37,9 @@ func PostComments(ctx context.Context, Id uint64) ([]wp.Comments, error) { } func postComments(args ...any) ([]uint64, error) { - postId := args[0].(uint64) - r, err := models.Find[wp.Comments](models.SqlBuilder{ + ctx := args[0].(context.Context) + postId := args[1].(uint64) + r, err := models.Find[wp.Comments](ctx, models.SqlBuilder{ {"comment_approved", "1"}, {"comment_post_ID", "=", strconv.FormatUint(postId, 10), "int"}, }, "comment_ID", "", models.SqlBuilder{ @@ -53,17 +55,18 @@ func postComments(args ...any) ([]uint64, error) { } func GetCommentById(ctx context.Context, id uint64) (wp.Comments, error) { - return commentsCache.GetCache(ctx, id, time.Second, id) + return commentsCache.GetCache(ctx, id, time.Second, ctx, id) } func GetCommentByIds(ctx context.Context, ids []uint64) ([]wp.Comments, error) { - return commentsCache.GetCacheBatch(ctx, ids, time.Second, ids) + return commentsCache.GetCacheBatch(ctx, ids, time.Second, ctx, ids) } func getCommentByIds(args ...any) (map[uint64]wp.Comments, error) { - ids := args[0].([]uint64) + ctx := args[0].(context.Context) + ids := args[1].([]uint64) m := make(map[uint64]wp.Comments) - r, err := models.SimpleFind[wp.Comments](models.SqlBuilder{ + r, err := models.SimpleFind[wp.Comments](ctx, models.SqlBuilder{ {"comment_ID", "in", ""}, {"comment_approved", "1"}, }, "*", helper.SliceMap(ids, helper.ToAny[uint64])) if err != nil { diff --git a/actions/common/common.go b/actions/common/common.go index 0d9d981..96b939d 100644 --- a/actions/common/common.go +++ b/actions/common/common.go @@ -77,15 +77,15 @@ type PostIds struct { type Arch struct { data []wp.PostArchive mutex *sync.Mutex - setCacheFunc func() ([]wp.PostArchive, error) + setCacheFunc func(context.Context) ([]wp.PostArchive, error) month time.Month } -func (c *Arch) getArchiveCache() []wp.PostArchive { +func (c *Arch) getArchiveCache(ctx context.Context) []wp.PostArchive { l := len(c.data) m := time.Now().Month() if l > 0 && c.month != m || l < 1 { - r, err := c.setCacheFunc() + r, err := c.setCacheFunc(ctx) if err != nil { logs.ErrPrintln(err, "set cache err[%s]") return nil @@ -103,25 +103,26 @@ type PostContext struct { next wp.Posts } -func archives() ([]wp.PostArchive, error) { - return models.Find[wp.PostArchive](models.SqlBuilder{ +func archives(ctx context.Context) ([]wp.PostArchive, error) { + return models.Find[wp.PostArchive](ctx, models.SqlBuilder{ {"post_type", "post"}, {"post_status", "publish"}, }, "YEAR(post_date) AS `year`, MONTH(post_date) AS `month`, count(ID) as posts", "year,month", models.SqlBuilder{{"year", "desc"}, {"month", "desc"}}, nil, nil, 0) } -func Archives() (r []wp.PostArchive) { - return archivesCaches.getArchiveCache() +func Archives(ctx context.Context) (r []wp.PostArchive) { + return archivesCaches.getArchiveCache(ctx) } func Categories(ctx context.Context) []wp.WpTermsMy { - r, err := categoryCaches.GetCache(ctx, time.Second) + r, err := categoryCaches.GetCache(ctx, time.Second, ctx) logs.ErrPrintln(err, "get category ") return r } -func categories(...any) (terms []wp.WpTermsMy, err error) { +func categories(a ...any) (terms []wp.WpTermsMy, err error) { + ctx := a[0].(context.Context) var in = []any{"category"} - terms, err = models.Find[wp.WpTermsMy](models.SqlBuilder{ + terms, err = models.Find[wp.WpTermsMy](ctx, models.SqlBuilder{ {"tt.count", ">", "0", "int"}, {"tt.taxonomy", "in", ""}, }, "t.term_id", "", models.SqlBuilder{ diff --git a/actions/common/posts.go b/actions/common/posts.go index 0d24059..bd718f0 100644 --- a/actions/common/posts.go +++ b/actions/common/posts.go @@ -14,11 +14,11 @@ import ( ) func GetPostById(ctx context.Context, id uint64) (wp.Posts, error) { - return postsCache.GetCache(ctx, id, time.Second, id) + return postsCache.GetCache(ctx, id, time.Second, ctx, id) } func GetPostsByIds(ctx context.Context, ids []uint64) ([]wp.Posts, error) { - return postsCache.GetCacheBatch(ctx, ids, time.Second, ids) + return postsCache.GetCacheBatch(ctx, ids, time.Second, ctx, ids) } func SearchPost(ctx context.Context, key string, args ...any) (r []wp.Posts, total int, err error) { @@ -32,10 +32,11 @@ func SearchPost(ctx context.Context, key string, args ...any) (r []wp.Posts, tot } func getPostsByIds(ids ...any) (m map[uint64]wp.Posts, err error) { + ctx := ids[0].(context.Context) m = make(map[uint64]wp.Posts) - id := ids[0].([]uint64) + id := ids[1].([]uint64) arg := helper.SliceMap(id, helper.ToAny[uint64]) - rawPosts, err := models.Find[wp.Posts](models.SqlBuilder{{ + rawPosts, err := models.Find[wp.Posts](ctx, models.SqlBuilder{{ "Id", "in", "", }}, "a.*,ifnull(d.name,'') category_name,ifnull(taxonomy,'') `taxonomy`", "", nil, models.SqlBuilder{{ "a", "left join", "wp_term_relationships b", "a.Id=b.object_id", @@ -91,14 +92,15 @@ func PostLists(ctx context.Context, key string, args ...any) (r []wp.Posts, tota } func searchPostIds(args ...any) (ids PostIds, err error) { - where := args[0].(models.SqlBuilder) - page := args[1].(int) - limit := args[2].(int) - order := args[3].(models.SqlBuilder) - join := args[4].(models.SqlBuilder) - postType := args[5].([]any) - postStatus := args[6].([]any) - res, total, err := models.SimplePagination[wp.Posts](where, "ID", "", page, limit, order, join, nil, postType, postStatus) + ctx := args[0].(context.Context) + where := args[1].(models.SqlBuilder) + page := args[2].(int) + limit := args[3].(int) + order := args[4].(models.SqlBuilder) + join := args[5].(models.SqlBuilder) + postType := args[6].([]any) + postStatus := args[7].([]any) + res, total, err := models.SimplePagination[wp.Posts](ctx, where, "ID", "", page, limit, order, join, nil, postType, postStatus) for _, posts := range res { ids.Ids = append(ids.Ids, posts.Id) } @@ -109,8 +111,9 @@ func searchPostIds(args ...any) (ids PostIds, err error) { return } -func getMaxPostId(...any) ([]uint64, error) { - r, err := models.SimpleFind[wp.Posts](models.SqlBuilder{{"post_type", "post"}, {"post_status", "publish"}}, "max(ID) ID") +func getMaxPostId(a ...any) ([]uint64, error) { + ctx := a[0].(context.Context) + r, err := models.SimpleFind[wp.Posts](ctx, models.SqlBuilder{{"post_type", "post"}, {"post_status", "publish"}}, "max(ID) ID") var id uint64 if len(r) > 0 { id = r[0].Id @@ -119,20 +122,21 @@ func getMaxPostId(...any) ([]uint64, error) { } func GetMaxPostId(ctx *gin.Context) (uint64, error) { - Id, err := maxPostIdCache.GetCache(ctx, time.Second) + Id, err := maxPostIdCache.GetCache(ctx, time.Second, ctx) return Id[0], err } func RecentPosts(ctx context.Context, n int) (r []wp.Posts) { - r, err := recentPostsCaches.GetCache(ctx, time.Second) + r, err := recentPostsCaches.GetCache(ctx, time.Second, ctx) if n < len(r) { r = r[:n] } logs.ErrPrintln(err, "get recent post") return } -func recentPosts(...any) (r []wp.Posts, err error) { - r, err = models.Find[wp.Posts](models.SqlBuilder{{ +func recentPosts(a ...any) (r []wp.Posts, err error) { + ctx := a[0].(context.Context) + r, err = models.Find[wp.Posts](ctx, models.SqlBuilder{{ "post_type", "post", }, {"post_status", "publish"}}, "ID,post_title,post_password", "", models.SqlBuilder{{"post_date", "desc"}}, nil, nil, 10) for i, post := range r { @@ -144,7 +148,7 @@ func recentPosts(...any) (r []wp.Posts, err error) { } func GetContextPost(ctx context.Context, id uint64, date time.Time) (prev, next wp.Posts, err error) { - postCtx, err := postContextCache.GetCache(ctx, id, time.Second, date) + postCtx, err := postContextCache.GetCache(ctx, id, time.Second, ctx, date) if err != nil { return wp.Posts{}, wp.Posts{}, err } @@ -154,8 +158,9 @@ func GetContextPost(ctx context.Context, id uint64, date time.Time) (prev, next } func getPostContext(arg ...any) (r PostContext, err error) { - t := arg[0].(time.Time) - next, err := models.FirstOne[wp.Posts](models.SqlBuilder{ + ctx := arg[0].(context.Context) + t := arg[1].(time.Time) + next, err := models.FirstOne[wp.Posts](ctx, models.SqlBuilder{ {"post_date", ">", t.Format("2006-01-02 15:04:05")}, {"post_status", "in", ""}, {"post_type", "post"}, @@ -166,7 +171,7 @@ func getPostContext(arg ...any) (r PostContext, err error) { if err != nil { return } - prev, err := models.FirstOne[wp.Posts](models.SqlBuilder{ + prev, err := models.FirstOne[wp.Posts](ctx, models.SqlBuilder{ {"post_date", "<", t.Format("2006-01-02 15:04:05")}, {"post_status", "in", ""}, {"post_type", "post"}, @@ -199,7 +204,8 @@ func GetMonthPostIds(ctx context.Context, year, month string, page, limit int, o } func monthPost(args ...any) (r []uint64, err error) { - year, month := args[0].(string), args[1].(string) + ctx := args[0].(context.Context) + year, month := args[1].(string), args[2].(string) where := models.SqlBuilder{ {"post_type", "in", ""}, {"post_status", "in", ""}, @@ -208,7 +214,7 @@ func monthPost(args ...any) (r []uint64, err error) { } postType := []any{"post"} status := []any{"publish"} - ids, err := models.Find[wp.Posts](where, "ID", "", models.SqlBuilder{{"Id", "asc"}}, nil, nil, 0, postType, status) + ids, err := models.Find[wp.Posts](ctx, where, "ID", "", models.SqlBuilder{{"Id", "asc"}}, nil, nil, 0, postType, status) if err != nil { return } diff --git a/actions/common/users.go b/actions/common/users.go index f2a69ca..390dc66 100644 --- a/actions/common/users.go +++ b/actions/common/users.go @@ -1,6 +1,7 @@ package common import ( + "context" "github.com/gin-gonic/gin" "github/fthvgb1/wp-go/logs" "github/fthvgb1/wp-go/models" @@ -8,9 +9,10 @@ import ( "time" ) -func getUsers(...any) (m map[uint64]wp.Users, err error) { +func getUsers(a ...any) (m map[uint64]wp.Users, err error) { m = make(map[uint64]wp.Users) - r, err := models.SimpleFind[wp.Users](nil, "*") + ctx := a[0].(context.Context) + r, err := models.SimpleFind[wp.Users](ctx, nil, "*") for _, user := range r { m[user.Id] = user } @@ -18,7 +20,7 @@ func getUsers(...any) (m map[uint64]wp.Users, err error) { } func GetUser(ctx *gin.Context, uid uint64) wp.Users { - r, err := usersCache.GetCache(ctx, uid, time.Second, uid) + r, err := usersCache.GetCache(ctx, uid, time.Second, ctx, uid) logs.ErrPrintln(err, "get user", uid) return r } diff --git a/actions/detail.go b/actions/detail.go index 47f169f..a1ac19e 100644 --- a/actions/detail.go +++ b/actions/detail.go @@ -28,7 +28,7 @@ func Detail(c *gin.Context) { c, } recent := common.RecentPosts(c, 5) - archive := common.Archives() + archive := common.Archives(c) categoryItems := common.Categories(c) recentComments := common.RecentComments(c, 5) var h = gin.H{ diff --git a/actions/index.go b/actions/index.go index 5e6a7ba..dd05c6a 100644 --- a/actions/index.go +++ b/actions/index.go @@ -162,7 +162,7 @@ func (h *indexHandle) getTotalPage(totalRaws int) int { func Index(c *gin.Context) { h := newIndexHandle(c) h.parseParams() - archive := common.Archives() + archive := common.Archives(c) recent := common.RecentPosts(c, 5) categoryItems := common.Categories(c) recentComments := common.RecentComments(c, 5) @@ -185,9 +185,9 @@ func Index(c *gin.Context) { return } } else if h.search != "" { - postIds, totalRaw, err = common.SearchPost(c, h.getSearchKey(), h.where, h.page, h.pageSize, models.SqlBuilder{{h.orderBy, h.order}}, h.join, h.postType, h.status) + postIds, totalRaw, err = common.SearchPost(c, h.getSearchKey(), c, h.where, h.page, h.pageSize, models.SqlBuilder{{h.orderBy, h.order}}, h.join, h.postType, h.status) } else { - postIds, totalRaw, err = common.PostLists(c, h.getSearchKey(), h.where, h.page, h.pageSize, models.SqlBuilder{{h.orderBy, h.order}}, h.join, h.postType, h.status) + postIds, totalRaw, err = common.PostLists(c, h.getSearchKey(), c, h.where, h.page, h.pageSize, models.SqlBuilder{{h.orderBy, h.order}}, h.join, h.postType, h.status) } defer func() { diff --git a/cache/map.go b/cache/map.go index 7cbad98..4da382c 100644 --- a/cache/map.go +++ b/cache/map.go @@ -48,8 +48,18 @@ func (m *MapCache[K, V]) SetCacheBatchFunc(fn func(...any) (map[K]V, error)) { func (m *MapCache[K, V]) setCacheFn(fn func(...any) (map[K]V, error)) { m.cacheFunc = func(a ...any) (V, error) { - id := a[0].(K) - r, err := fn([]K{id}) + var err error + var r map[K]V + var id K + ctx, ok := a[0].(context.Context) + if ok { + id = a[1].(K) + r, err = fn(ctx, []K{id}) + } else { + id = a[0].(K) + r, err = fn([]K{id}) + } + if err != nil { var rr V return rr, err diff --git a/db/db.go b/db/db.go index dcb3736..306444c 100644 --- a/db/db.go +++ b/db/db.go @@ -1,6 +1,7 @@ package db import ( + "context" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" "github/fthvgb1/wp-go/config" @@ -8,6 +9,22 @@ import ( var Db *sqlx.DB +type SqlxDb struct { + sqlx *sqlx.DB +} + +func NewSqlxDb(sqlx *sqlx.DB) *SqlxDb { + return &SqlxDb{sqlx: sqlx} +} + +func (r SqlxDb) Select(ctx context.Context, dest any, sql string, params ...any) error { + return r.sqlx.Select(dest, sql, params...) +} + +func (r SqlxDb) Get(ctx context.Context, dest any, sql string, params ...any) error { + return r.sqlx.Get(dest, sql, params...) +} + func InitDb() error { dsn := config.Conf.Mysql.Dsn.GetDsn() var err error diff --git a/main.go b/main.go index 648dc9a..489fa3f 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "github/fthvgb1/wp-go/actions/common" "github/fthvgb1/wp-go/config" "github/fthvgb1/wp-go/db" + "github/fthvgb1/wp-go/models" "github/fthvgb1/wp-go/models/wp" "github/fthvgb1/wp-go/plugins" "github/fthvgb1/wp-go/route" @@ -27,7 +28,7 @@ func init() { if err != nil { panic(err) } - + models.InitDB(db.NewSqlxDb(db.Db)) err = wp.InitOptions() if err != nil { panic(err) diff --git a/models/model.go b/models/model.go index 9185620..e147a06 100644 --- a/models/model.go +++ b/models/model.go @@ -1,6 +1,7 @@ package models import ( + "context" "fmt" "github/fthvgb1/wp-go/helper" "strconv" @@ -8,6 +9,11 @@ import ( ) var _ ParseWhere = SqlBuilder{} +var globalBb dbQuery + +func InitDB(db dbQuery) { + globalBb = db +} type Model interface { PrimaryKey() string @@ -18,6 +24,11 @@ type ParseWhere interface { ParseWhere(*[][]any) (string, []any, error) } +type dbQuery interface { + Select(context.Context, any, string, ...any) error + Get(context.Context, any, string, ...any) error +} + type SqlBuilder [][]string func (w SqlBuilder) parseField(ss []string, s *strings.Builder) { diff --git a/models/query.go b/models/query.go index 3cc65aa..8fa01ef 100644 --- a/models/query.go +++ b/models/query.go @@ -1,15 +1,15 @@ package models import ( + "context" "fmt" - "github/fthvgb1/wp-go/db" "github/fthvgb1/wp-go/helper" "math/rand" "strings" "time" ) -func SimplePagination[T Model](where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { +func SimplePagination[T Model](ctx context.Context, where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, having SqlBuilder, in ...[]any) (r []T, total int, err error) { var rr T var w string var args []any @@ -56,12 +56,12 @@ func SimplePagination[T Model](where ParseWhere, fields, group string, page, pag if group == "" { tpx := "select count(*) n from %s %s %s limit 1" sq := fmt.Sprintf(tpx, rr.Table(), j, w) - err = db.Db.Get(&n, sq, args...) + err = globalBb.Get(ctx, &n, sq, args...) } else { tpx := "select count(*) n from (select %s from %s %s %s %s %s ) %s" rand.Seed(int64(time.Now().Nanosecond())) sq := fmt.Sprintf(tpx, group, rr.Table(), j, w, groupBy, h, fmt.Sprintf("table%d", rand.Int())) - err = db.Db.Get(&n, sq, args...) + err = globalBb.Get(ctx, &n, sq, args...) } if err != nil { @@ -80,24 +80,24 @@ func SimplePagination[T Model](where ParseWhere, fields, group string, page, pag } tp := "select %s from %s %s %s %s %s %s limit %d,%d" sql := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), offset, pageSize) - err = db.Db.Select(&r, sql, args...) + err = globalBb.Select(ctx, &r, sql, args...) if err != nil { return } return } -func FindOneById[T Model, I helper.IntNumber](id I) (T, error) { +func FindOneById[T Model, I helper.IntNumber](ctx context.Context, id I) (T, error) { var r T sql := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) - err := db.Db.Get(&r, sql, id) + err := globalBb.Get(ctx, &r, sql, id) if err != nil { return r, err } return r, nil } -func FirstOne[T Model](where ParseWhere, fields string, order SqlBuilder, in ...[]any) (T, error) { +func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (T, error) { var r T var w string var args []any @@ -110,14 +110,14 @@ func FirstOne[T Model](where ParseWhere, fields string, order SqlBuilder, in ... } tp := "select %s from %s %s %s" sql := fmt.Sprintf(tp, fields, r.Table(), w, order.parseOrderBy()) - err = db.Db.Get(&r, sql, args...) + err = globalBb.Get(ctx, &r, sql, args...) if err != nil { return r, err } return r, nil } -func LastOne[T Model](where ParseWhere, fields string, in ...[]any) (T, error) { +func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (T, error) { var r T var w string var args []any @@ -130,14 +130,14 @@ func LastOne[T Model](where ParseWhere, fields string, in ...[]any) (T, error) { } tp := "select %s from %s %s order by %s desc limit 1" sql := fmt.Sprintf(tp, fields, r.Table(), w, r.PrimaryKey()) - err = db.Db.Get(&r, sql, args...) + err = globalBb.Get(ctx, &r, sql, args...) if err != nil { return r, err } return r, nil } -func SimpleFind[T Model](where ParseWhere, fields string, in ...[]any) ([]T, error) { +func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) ([]T, error) { var r []T var rr T var err error @@ -151,25 +151,25 @@ func SimpleFind[T Model](where ParseWhere, fields string, in ...[]any) ([]T, err } tp := "select %s from %s %s" sql := fmt.Sprintf(tp, fields, rr.Table(), w) - err = db.Db.Select(&r, sql, args...) + err = globalBb.Select(ctx, &r, sql, args...) if err != nil { return r, err } return r, nil } -func Select[T Model](sql string, params ...any) ([]T, error) { +func Select[T Model](ctx context.Context, sql string, params ...any) ([]T, error) { var r []T var rr T sql = strings.Replace(sql, "{table}", rr.Table(), -1) - err := db.Db.Select(&r, sql, params...) + err := globalBb.Select(ctx, &r, sql, params...) if err != nil { return r, err } return r, nil } -func Find[T Model](where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) { +func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, order SqlBuilder, join SqlBuilder, having SqlBuilder, limit int, in ...[]any) (r []T, err error) { var rr T w := "" var args []any @@ -203,12 +203,12 @@ func Find[T Model](where ParseWhere, fields, group string, order SqlBuilder, joi l = fmt.Sprintf(" limit %d", limit) } sql := fmt.Sprintf(tp, fields, rr.Table(), j, w, groupBy, h, order.parseOrderBy(), l) - err = db.Db.Select(&r, sql, args...) + err = globalBb.Select(ctx, &r, sql, args...) return } -func Get[T Model](sql string, params ...any) (r T, err error) { +func Get[T Model](ctx context.Context, sql string, params ...any) (r T, err error) { sql = strings.Replace(sql, "{table}", r.Table(), -1) - err = db.Db.Get(&r, sql, params...) + err = globalBb.Get(ctx, &r, sql, params...) return } diff --git a/models/wp/globalInit.go b/models/wp/globalInit.go index 803db1a..dabc21b 100644 --- a/models/wp/globalInit.go +++ b/models/wp/globalInit.go @@ -1,18 +1,22 @@ package wp -import "github/fthvgb1/wp-go/models" +import ( + "context" + "github/fthvgb1/wp-go/models" +) var Option = make(map[string]string) var Terms = map[uint64]WpTerms{} var TermTaxonomies = map[uint64]TermTaxonomy{} func InitOptions() error { - ops, err := models.SimpleFind[Options](models.SqlBuilder{{"autoload", "yes"}}, "option_name, option_value") + ctx := context.Background() + ops, err := models.SimpleFind[Options](ctx, models.SqlBuilder{{"autoload", "yes"}}, "option_name, option_value") if err != nil { return err } if len(ops) == 0 { - ops, err = models.SimpleFind[Options](nil, "option_name, option_value") + ops, err = models.SimpleFind[Options](ctx, nil, "option_name, option_value") if err != nil { return err } @@ -24,14 +28,15 @@ func InitOptions() error { } func InitTerms() (err error) { - terms, err := models.SimpleFind[WpTerms](nil, "*") + ctx := context.Background() + terms, err := models.SimpleFind[WpTerms](ctx, nil, "*") if err != nil { return err } for _, wpTerms := range terms { Terms[wpTerms.TermId] = wpTerms } - termTax, err := models.SimpleFind[TermTaxonomy](nil, "*") + termTax, err := models.SimpleFind[TermTaxonomy](ctx, nil, "*") if err != nil { return err }