diff --git a/model/query.go b/model/query.go index b8dad98..0e468b1 100644 --- a/model/query.go +++ b/model/query.go @@ -9,6 +9,19 @@ import ( "strings" ) +type count[T Model] struct { + t T + N int `json:"n,omitempty" db:"n" gorm:"n"` +} + +func (c count[T]) PrimaryKey() string { + return c.t.PrimaryKey() +} + +func (c count[T]) Table() string { + return c.t.Table() +} + func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page, pageSize int) (r []T, total int, err error) { if page < 1 || pageSize < 1 { return @@ -21,6 +34,7 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page In: q.In, Group: q.Group, From: q.From, + Fields: "count(*) n", } if q.Group != "" { qx.Fields = q.Fields @@ -32,13 +46,13 @@ func pagination[T Model](db dbQuery, ctx context.Context, q QueryCondition, page } qx.From = str.Join("( ", sq, " ) ", "table", number.ToString(rand.Int())) qx = QueryCondition{ - From: qx.From, - In: qx.In, + From: qx.From, + In: qx.In, + Fields: "count(*) n", } } - - n, err := GetField[T](ctx, "count(*)", qx) - total = str.ToInt[int](n) + n, err := gets[count[T]](db, ctx, qx) + total = n.N if err != nil || total < 1 { return }