diff --git a/internal/wpconfig/options.go b/internal/wpconfig/options.go index 84f4fa4..615636d 100644 --- a/internal/wpconfig/options.go +++ b/internal/wpconfig/options.go @@ -17,18 +17,15 @@ func InitOptions() error { if ctx == nil { ctx = context.Background() } - ops, err := model.SimpleFind[models.Options](ctx, model.SqlBuilder{{"autoload", "yes"}}, "option_name, option_value") + ops, err := model.FindToStringMap[models.Options](ctx, model.Conditions( + model.Where(model.SqlBuilder{{"autoload", "yes"}}), + model.Fields("option_name, option_value"), + )) if err != nil { return err } - if len(ops) == 0 { - ops, err = model.SimpleFind[models.Options](ctx, nil, "option_name, option_value") - if err != nil { - return err - } - } for _, option := range ops { - options.Store(option.OptionName, option.OptionValue) + options.Store(option["option_name"], option["option_value"]) } return nil } @@ -38,7 +35,7 @@ func GetOption(k string) string { if ok { return v } - vv, err := model.GetField[models.Options, string](ctx, "option_value", model.Conditions(model.Where(model.SqlBuilder{{"option_name", k}}))) + vv, err := model.GetField[models.Options](ctx, "option_value", model.Conditions(model.Where(model.SqlBuilder{{"option_name", k}}))) options.Store(k, vv) if err != nil { return "" diff --git a/model/query_test.go b/model/query_test.go index 7db74fe..a47482a 100644 --- a/model/query_test.go +++ b/model/query_test.go @@ -100,12 +100,15 @@ func (p post) Table() string { var ctx = context.Background() +var glob *SqlxQuery + func init() { db, err := sqlx.Open("mysql", "root:root@tcp(192.168.66.47:3306)/wordpress?charset=utf8mb4&parseTime=True&loc=Local") if err != nil { panic(err) } - InitDB(NewSqlxQuery(db, NewUniversalDb(nil, nil))) + glob = NewSqlxQuery(db, NewUniversalDb(nil, nil)) + InitDB(glob) } func TestFind(t *testing.T) { type args struct { diff --git a/model/querycondition.go b/model/querycondition.go index 7c8160c..cebd65a 100644 --- a/model/querycondition.go +++ b/model/querycondition.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "fmt" - "github.com/fthvgb1/wp-go/helper/maps" "github.com/fthvgb1/wp-go/helper/slice" "strings" ) @@ -156,51 +155,65 @@ func column[V Model, T any](db dbQuery, ctx context.Context, fn func(V) (T, bool return } -func GetField[T Model, V any](ctx context.Context, field string, q *QueryCondition) (r V, err error) { - r, err = getField[T, V](globalBb, ctx, field, q) +func GetField[T Model](ctx context.Context, field string, q *QueryCondition) (r string, err error) { + r, err = getField[T](globalBb, ctx, field, q) return } -func getField[T Model, V any](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r V, err error) { - res, err := getToAnyMap[T](globalBb, ctx, q) +func getField[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { + if q.fields == "" || q.fields == "*" { + q.fields = field + } + res, err := getToStringMap[T](db, ctx, q) if err != nil { return } - r, ok := maps.GetStrAnyVal[V](res, field) + f := strings.Split(field, " ") + r, ok := res[f[len(f)-1]] if !ok { err = errors.New("not exists") } return } -func GetFieldFromDB[T Model, V any](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r V, err error) { - return getField[T, V](db, ctx, field, q) +func GetFieldFromDB[T Model](db dbQuery, ctx context.Context, field string, q *QueryCondition) (r string, err error) { + return getField[T](db, ctx, field, q) } -func findToAnyMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]any, err error) { +func getToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { rawSql, in, err := FindRawSql[T](q) if err != nil { return nil, err } - ctx = context.WithValue(ctx, "toMap", true) + ctx = context.WithValue(ctx, "toMap", "string") + err = db.Get(ctx, &r, rawSql, in...) + return +} +func GetToStringMap[T Model](ctx context.Context, q *QueryCondition) (r map[string]string, err error) { + r, err = getToStringMap[T](globalBb, ctx, q) + return +} + +func findToStringMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { + rawSql, in, err := FindRawSql[T](q) + if err != nil { + return nil, err + } + ctx = context.WithValue(ctx, "toMap", "string") err = db.Select(ctx, &r, rawSql, in...) return } -func FindToAnyMap[T Model](ctx context.Context, q *QueryCondition) (r []map[string]any, err error) { - r, err = findToAnyMap[T](globalBb, ctx, q) +func FindToStringMap[T Model](ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { + r, err = findToStringMap[T](globalBb, ctx, q) return } -func FindToAnyMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]any, err error) { - r, err = findToAnyMap[T](db, ctx, q) +func FindToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r []map[string]string, err error) { + r, err = findToStringMap[T](db, ctx, q) return } -func getToAnyMap[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]any, err error) { - rawSql, in, err := FindRawSql[T](q) - if err != nil { - return nil, err - } - ctx = context.WithValue(ctx, "toMap", true) - err = db.Get(ctx, &r, rawSql, in...) + +func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q *QueryCondition) (r map[string]string, err error) { + r, err = getToStringMap[T](db, ctx, q) return } diff --git a/model/querycondition_test.go b/model/querycondition_test.go index 1718b23..1d37608 100644 --- a/model/querycondition_test.go +++ b/model/querycondition_test.go @@ -274,44 +274,6 @@ func TestColumn(t *testing.T) { } } -func TestGetField(t *testing.T) { - type args struct { - ctx context.Context - field string - q *QueryCondition - } - type testCase[V any] struct { - name string - args args - wantR V - wantErr bool - } - tests := []testCase[string]{ - { - name: "t1", - args: args{ - ctx: ctx, - field: "option_value", - q: Conditions(Where(SqlBuilder{{"option_name", "blogname"}})), - }, - wantR: "记录并见证自己的成长", - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotR, err := GetField[options, string](tt.args.ctx, tt.args.field, tt.args.q) - if (err != nil) != tt.wantErr { - t.Errorf("GetField() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotR, tt.wantR) { - t.Errorf("GetField() gotR = %v, want %v", gotR, tt.wantR) - } - }) - } -} - type options struct { OptionId uint64 `gorm:"column:option_id" db:"option_id" json:"option_id" form:"option_id"` OptionName string `gorm:"column:option_name" db:"option_name" json:"option_name" form:"option_name"` @@ -326,3 +288,190 @@ func (w options) PrimaryKey() string { func (w options) Table() string { return "wp_options" } + +func Test_getField(t *testing.T) { + { + name := "string" + db := glob + field := "option_value" + q := Conditions(Where(SqlBuilder{{"option_name", "blogname"}})) + wantR := "记录并见证自己的成长" + wantErr := false + t.Run(name, func(t *testing.T) { + gotR, err := getField[options](db, ctx, field, q) + if (err != nil) != wantErr { + t.Errorf("getField() error = %v, wantErr %v", err, wantErr) + return + } + if !reflect.DeepEqual(gotR, wantR) { + t.Errorf("getField() gotR = %v, want %v", gotR, wantR) + } + }) + } + + { + name := "t2" + db := glob + field := "option_id" + q := Conditions(Where(SqlBuilder{{"option_name", "blogname"}})) + wantR := "3" + wantErr := false + t.Run(name, func(t *testing.T) { + gotR, err := getField[options](db, ctx, field, q) + if (err != nil) != wantErr { + t.Errorf("getField() error = %v, wantErr %v", err, wantErr) + return + } + if !reflect.DeepEqual(gotR, wantR) { + t.Errorf("getField() gotR = %v, want %v", gotR, wantR) + } + }) + } + { + name := "count(*)" + db := glob + field := "count(*)" + q := Conditions() + wantR := "386" + wantErr := false + t.Run(name, func(t *testing.T) { + gotR, err := getField[options](db, ctx, field, q) + if (err != nil) != wantErr { + t.Errorf("getField() error = %v, wantErr %v", err, wantErr) + return + } + if !reflect.DeepEqual(gotR, wantR) { + t.Errorf("getField() gotR = %v, want %v", gotR, wantR) + } + }) + } +} + +func Test_getToStringMap(t *testing.T) { + type args struct { + db dbQuery + ctx context.Context + q *QueryCondition + } + tests := []struct { + name string + args args + wantR map[string]string + wantErr bool + }{ + { + name: "t1", + args: args{ + db: glob, + ctx: ctx, + q: Conditions(Where(SqlBuilder{{"option_name", "users_can_register"}})), + }, + wantR: map[string]string{ + "option_id": "5", + "option_value": "0", + "option_name": "users_can_register", + "autoload": "yes", + }, + }, + { + name: "t2", + args: args{ + db: glob, + ctx: ctx, + q: Conditions( + Where(SqlBuilder{{"option_name", "users_can_register"}}), + Fields("option_id id"), + ), + }, + wantR: map[string]string{ + "id": "5", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotR, err := getToStringMap[options](tt.args.db, tt.args.ctx, tt.args.q) + if (err != nil) != tt.wantErr { + t.Errorf("getToStringMap() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotR, tt.wantR) { + t.Errorf("getToStringMap() gotR = %v, want %v", gotR, tt.wantR) + } + }) + } +} + +func Test_findToStringMap(t *testing.T) { + type args struct { + db dbQuery + ctx context.Context + q *QueryCondition + } + tests := []struct { + name string + args args + wantR []map[string]string + wantErr bool + }{ + { + name: "t1", + args: args{ + db: glob, + ctx: ctx, + q: Conditions(Where(SqlBuilder{{"option_id", "5"}})), + }, + wantR: []map[string]string{{ + "option_id": "5", + "option_value": "0", + "option_name": "users_can_register", + "autoload": "yes", + }}, + wantErr: false, + }, + { + name: "t2", + args: args{ + db: glob, + ctx: ctx, + q: Conditions( + Where(SqlBuilder{{"option_id", "5"}}), + Fields("option_value,option_name"), + ), + }, + wantR: []map[string]string{{ + "option_value": "0", + "option_name": "users_can_register", + }}, + wantErr: false, + }, + { + name: "t3", + args: args{ + db: glob, + ctx: ctx, + q: Conditions( + Where(SqlBuilder{{"option_id", "5"}}), + Fields("option_value v,option_name k"), + ), + }, + wantR: []map[string]string{{ + "v": "0", + "k": "users_can_register", + }}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotR, err := findToStringMap[options](tt.args.db, tt.args.ctx, tt.args.q) + if (err != nil) != tt.wantErr { + t.Errorf("findToStringMap() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotR, tt.wantR) { + t.Errorf("findToStringMap() gotR = %v, want %v", gotR, tt.wantR) + } + }) + } +} diff --git a/model/sqxquery.go b/model/sqxquery.go index ef877cc..ec1359b 100644 --- a/model/sqxquery.go +++ b/model/sqxquery.go @@ -3,6 +3,7 @@ package model import ( "context" "fmt" + "github.com/fthvgb1/wp-go/helper/slice" "github.com/jmoiron/sqlx" "strconv" "strings" @@ -35,11 +36,11 @@ func SetGet(db *SqlxQuery, fn func(context.Context, any, string, ...any) error) func (r *SqlxQuery) Selects(ctx context.Context, dest any, sql string, params ...any) error { v := ctx.Value("toMap") if v != nil { - vv, ok := v.(bool) - if ok && vv { - d, ok := dest.(*[]map[string]any) - if ok { - return r.toMapSlice(d, sql, params...) + vv, ok := v.(string) + if ok && vv != "" { + switch vv { + case "string": + return ToMapSlice(r.sqlx, dest.(*[]map[string]string), sql, params...) } } } @@ -49,77 +50,39 @@ func (r *SqlxQuery) Selects(ctx context.Context, dest any, sql string, params .. func (r *SqlxQuery) Gets(ctx context.Context, dest any, sql string, params ...any) error { v := ctx.Value("toMap") if v != nil { - vv, ok := v.(bool) - if ok && vv { - d, ok := dest.(*map[string]any) - if ok { - return r.toMap(d, sql, params...) + vv, ok := v.(string) + if ok && vv != "" { + switch vv { + case "string": + return GetToMap(r.sqlx, dest.(*map[string]string), sql, params...) } } } return r.sqlx.Get(dest, sql, params...) } -func (r *SqlxQuery) toMap(dest *map[string]any, sql string, params ...any) (err error) { - rows := r.sqlx.QueryRowx(sql, params...) - columns, err := rows.Columns() - if err != nil { - return err - } - columnLen := len(columns) - c := make([]any, columnLen) - for i, _ := range c { - var a any - c[i] = &a - } - err = rows.Scan(c...) - if err != nil { - return - } - v := make(map[string]any) - for i, data := range c { - s, ok := data.(*any) - if ok { - ss, ok := (*s).([]uint8) - if ok { - data = string(ss) - } - } - v[columns[i]] = data - } - *dest = v - return -} - -func (r *SqlxQuery) toMapSlice(dest *[]map[string]any, sql string, params ...any) (err error) { - rows, err := r.sqlx.Query(sql, params...) +func ToMapSlice[V any](db *sqlx.DB, dest *[]map[string]V, sql string, params ...any) (err error) { + rows, err := db.Query(sql, params...) columns, err := rows.Columns() if err != nil { return err } defer rows.Close() columnLen := len(columns) - c := make([]any, columnLen) + c := make([]*V, columnLen) for i, _ := range c { - var a any + var a V c[i] = &a } - var m []map[string]any + var m []map[string]V for rows.Next() { - err = rows.Scan(c...) + err = rows.Scan(slice.ToAnySlice(c)...) if err != nil { return } - v := make(map[string]any) + v := make(map[string]V) for i, data := range c { - s, ok := data.(*any) - if ok { - ss, ok := (*s).([]uint8) - if ok { - data = string(ss) - } - } - v[columns[i]] = data + v[columns[i]] = *data } m = append(m, v) } @@ -127,6 +90,30 @@ func (r *SqlxQuery) toMapSlice(dest *[]map[string]any, sql string, params ...any return } +func GetToMap[V any](db *sqlx.DB, dest *map[string]V, sql string, params ...any) (err error) { + rows := db.QueryRowx(sql, params...) + columns, err := rows.Columns() + if err != nil { + return err + } + columnLen := len(columns) + c := make([]*V, columnLen) + for i, _ := range c { + var a V + c[i] = &a + } + err = rows.Scan(slice.ToAnySlice(c)...) + if err != nil { + return + } + v := make(map[string]V) + for i, data := range c { + v[columns[i]] = *data + } + *dest = v + return +} + func FormatSql(sql string, params ...any) string { for _, param := range params { switch param.(type) {