diff --git a/model/model.go b/model/model.go index c2d83a8..0a426fe 100644 --- a/model/model.go +++ b/model/model.go @@ -28,3 +28,13 @@ type dbQuery interface { } type SqlBuilder [][]string + +func Table[T Model]() string { + var r T + return r.Table() +} + +func PrimaryKey[T Model]() string { + var r T + return r.PrimaryKey() +} diff --git a/model/query.go b/model/query.go index 8d7da6b..fb2268e 100644 --- a/model/query.go +++ b/model/query.go @@ -2,7 +2,6 @@ package model import ( "context" - "fmt" "github.com/fthvgb1/wp-go/helper/number" str "github.com/fthvgb1/wp-go/helper/strings" "golang.org/x/exp/constraints" @@ -59,13 +58,12 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition) (r [ } func FindOneById[T Model, I constraints.Integer](ctx context.Context, id I) (T, error) { - var r T - sq := fmt.Sprintf("select * from `%s` where `%s`=?", r.Table(), r.PrimaryKey()) - err := globalBb.Get(ctx, &r, sq, id) - if err != nil { - return r, err - } - return r, nil + return gets[T](globalBb, ctx, QueryCondition{ + Fields: "*", + Where: SqlBuilder{ + {PrimaryKey[T](), "=", number.ToString(id), "int"}, + }, + }) } func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, order SqlBuilder, in ...[]any) (r T, err error) { @@ -84,23 +82,13 @@ func FirstOne[T Model](ctx context.Context, where ParseWhere, fields string, ord } func LastOne[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (T, error) { - var r T - var w string - var args []any - var err error - if where != nil { - w, args, err = where.ParseWhere(&in) - if err != nil { - return r, err - } - } - tp := "select %s from %s %s order by %s desc limit 1" - sq := fmt.Sprintf(tp, fields, r.Table(), w, r.PrimaryKey()) - err = globalBb.Get(ctx, &r, sq, args...) - if err != nil { - return r, err - } - return r, nil + return gets[T](globalBb, ctx, Conditions( + Where(where), + Fields(fields), + In(in...), + Order(SqlBuilder{{PrimaryKey[T](), "desc"}}), + Limit(1), + )) } func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, in ...[]any) (r []T, err error) { @@ -116,10 +104,10 @@ func SimpleFind[T Model](ctx context.Context, where ParseWhere, fields string, i return r, nil } +// Select 如果查询的为T的表名,可以使用 {table}来代替 func Select[T Model](ctx context.Context, sql string, params ...any) ([]T, error) { var r []T - var rr T - sql = strings.Replace(sql, "{table}", rr.Table(), -1) + sql = strings.Replace(sql, "{table}", Table[T](), -1) err := globalBb.Select(ctx, &r, sql, params...) if err != nil { return r, err @@ -146,6 +134,7 @@ func Find[T Model](ctx context.Context, where ParseWhere, fields, group string, return } +// Get 可以使用 {table}来替代 T的表名 func Get[T Model](ctx context.Context, sql string, params ...any) (r T, err error) { sql = strings.Replace(sql, "{table}", r.Table(), -1) err = globalBb.Get(ctx, &r, sql, params...) diff --git a/model/querycondition.go b/model/querycondition.go index ec3c8c9..e1b703f 100644 --- a/model/querycondition.go +++ b/model/querycondition.go @@ -218,7 +218,6 @@ func GetToStringMapFromDB[T Model](db dbQuery, ctx context.Context, q QueryCondi } func BuildQuerySql[T Model](q QueryCondition) (r string, args []any, err error) { - var rr T w := "" if q.Where != nil { w, args, err = q.Where.ParseWhere(&q.In) @@ -252,7 +251,7 @@ func BuildQuerySql[T Model](q QueryCondition) (r string, args []any, err error) } tp := "select %s from %s %s %s %s %s %s %s" l := "" - table := rr.Table() + table := Table[T]() if q.From != "" { table = q.From } diff --git a/model/querycondition_test.go b/model/querycondition_test.go index 636489b..ae303f1 100644 --- a/model/querycondition_test.go +++ b/model/querycondition_test.go @@ -333,7 +333,7 @@ func Test_getField(t *testing.T) { db := glob field := "count(*)" q := Conditions() - wantR := "387" + wantR := "385" wantErr := false t.Run(name, func(t *testing.T) { gotR, err := getField[options](db, ctx, field, q)