diff --git a/app/pkg/db/db.go b/app/pkg/db/db.go index 2f6b507..789865e 100644 --- a/app/pkg/db/db.go +++ b/app/pkg/db/db.go @@ -15,6 +15,14 @@ import ( var safeDb = safety.NewVar[*sqlx.DB](nil) var showQuerySql func() bool +func GetSqlxDB() *sqlx.DB { + return safeDb.Load() +} + +func SetSqlxDB(db *sqlx.DB) { + safeDb.Store(db) +} + func InitDb() (*safety.Var[*sqlx.DB], error) { c := config.GetConfig() dsn := c.Mysql.Dsn.GetDsn() diff --git a/model/sqxquery.go b/model/sqxquery.go index c4ba7c4..e19b5fc 100644 --- a/model/sqxquery.go +++ b/model/sqxquery.go @@ -41,14 +41,14 @@ func (r *SqlxQuery) Selects(ctx context.Context, dest any, sql string, params .. if v != "" { switch v { case "string": - return ToMapSlice(db, dest.(*[]map[string]string), sql, params...) + return ToMapSlice(ctx, db, dest.(*[]map[string]string), sql, params...) case "scanner": fn := ctx.Value("fn") - return Scanner[any](db, dest, sql, params...)(fn.(func(any))) + return Scanner[any](ctx, db, dest, sql, params...)(fn.(func(any))) } } - return db.Select(dest, sql, params...) + return db.SelectContext(ctx, dest, sql, params...) } func (r *SqlxQuery) Gets(ctx context.Context, dest any, sql string, params ...any) error { @@ -57,15 +57,15 @@ func (r *SqlxQuery) Gets(ctx context.Context, dest any, sql string, params ...an if v != "" { switch v { case "string": - return GetToMap(db, dest.(*map[string]string), sql, params...) + return GetToMap(ctx, db, dest.(*map[string]string), sql, params...) } } - return db.Get(dest, sql, params...) + return db.GetContext(ctx, dest, sql, params...) } -func Scanner[T any](db *sqlx.DB, v T, s string, params ...any) func(func(T)) error { +func Scanner[T any](ctx context.Context, db *sqlx.DB, v T, s string, params ...any) func(func(T)) error { return func(fn func(T)) error { - rows, err := db.Queryx(s, params...) + rows, err := db.QueryxContext(ctx, s, params...) if err != nil { return err } @@ -80,8 +80,8 @@ func Scanner[T any](db *sqlx.DB, v T, s string, params ...any) func(func(T)) err } } -func ToMapSlice[V any](db *sqlx.DB, dest *[]map[string]V, sql string, params ...any) (err error) { - rows, err := db.Query(sql, params...) +func ToMapSlice[V any](ctx context.Context, db *sqlx.DB, dest *[]map[string]V, sql string, params ...any) (err error) { + rows, err := db.QueryContext(ctx, sql, params...) if err != nil { return err } @@ -113,8 +113,8 @@ func ToMapSlice[V any](db *sqlx.DB, dest *[]map[string]V, sql string, params ... return } -func GetToMap[V any](db *sqlx.DB, dest *map[string]V, sql string, params ...any) (err error) { - rows := db.QueryRowx(sql, params...) +func GetToMap[V any](ctx context.Context, db *sqlx.DB, dest *map[string]V, sql string, params ...any) (err error) { + rows := db.QueryRowxContext(ctx, sql, params...) columns, err := rows.Columns() if err != nil { return err