diff --git a/models/model.go b/models/model.go index 009a303..9032d96 100644 --- a/models/model.go +++ b/models/model.go @@ -4,6 +4,7 @@ import ( "fmt" "github/fthvgb1/wp-go/db" "github/fthvgb1/wp-go/helper" + "strconv" "strings" ) @@ -25,9 +26,10 @@ func (m model[T]) Table() string { type SqlBuilder [][]string -func (w SqlBuilder) parseWhere() (string, []interface{}) { +func (w SqlBuilder) parseWhere(in ...[]interface{}) (string, []interface{}) { var s strings.Builder args := make([]interface{}, 0, len(w)) + c := 0 for _, ss := range w { if len(ss) == 2 { s.WriteString("`") @@ -35,13 +37,35 @@ func (w SqlBuilder) parseWhere() (string, []interface{}) { s.WriteString("`=? and ") args = append(args, ss[1]) } - if len(ss) == 3 { + if len(ss) >= 3 { s.WriteString("`") s.WriteString(ss[0]) s.WriteString("`") s.WriteString(ss[1]) + if ss[1] == "in" && len(in) > 0 { + s.WriteString(" (") + for _, p := range in[c] { + s.WriteString("?,") + args = append(args, p) + } + sx := s.String() + s.Reset() + s.WriteString(strings.TrimRight(sx, ",")) + s.WriteString(")") + c++ + s.WriteString(" and ") + continue + } s.WriteString(" ? and ") - args = append(args, ss[2]) + if len(ss) == 4 && ss[3] == "int" { + i, _ := strconv.Atoi(ss[2]) + args = append(args, i) + } else if len(ss) == 4 && ss[3] == "float" { + i, _ := strconv.ParseFloat(ss[2], 64) + args = append(args, i) + } else { + args = append(args, ss[2]) + } } } ss := strings.TrimRight(s.String(), "and ") @@ -91,9 +115,26 @@ func (w SqlBuilder) parseJoin() string { return s.String() } -func (m model[T]) SimplePagination(where SqlBuilder, fields string, page, pageSize int, order SqlBuilder, join SqlBuilder) (r []T, total int, err error) { +func (m model[T]) Find(where SqlBuilder, fields string, order SqlBuilder, join SqlBuilder, limit int, in ...[]interface{}) (r []T, err error) { var rr T - w, args := where.parseWhere() + w, args := where.parseWhere(in...) + j := join.parseJoin() + tp := "select %s from %s %s %s %s %s" + l := "" + if limit > 0 { + l = fmt.Sprintf(" limit %d", limit) + } + sql := fmt.Sprintf(tp, fields, rr.Table(), j, w, order.parseOrderBy(), l) + err = db.Db.Select(&r, sql, args...) + if err != nil { + return + } + return +} + +func (m model[T]) SimplePagination(where SqlBuilder, fields string, page, pageSize int, order SqlBuilder, join SqlBuilder, in ...[]interface{}) (r []T, total int, err error) { + var rr T + w, args := where.parseWhere(in...) n := struct { N int `db:"n" json:"n"` }{} @@ -134,9 +175,9 @@ func (m model[T]) FindOneById(id int) (T, error) { return r, nil } -func (m model[T]) FirstOne(where SqlBuilder, fields string) (T, error) { +func (m model[T]) FirstOne(where SqlBuilder, fields string, in ...[]interface{}) (T, error) { var r T - w, args := where.parseWhere() + w, args := where.parseWhere(in...) tp := "select %s from %s %s" sql := fmt.Sprintf(tp, fields, r.Table(), w) err := db.Db.Get(&r, sql, args...) @@ -146,9 +187,9 @@ func (m model[T]) FirstOne(where SqlBuilder, fields string) (T, error) { return r, nil } -func (m model[T]) LastOne(where SqlBuilder, fields string) (T, error) { +func (m model[T]) LastOne(where SqlBuilder, fields string, in ...[]interface{}) (T, error) { var r T - w, args := where.parseWhere() + w, args := where.parseWhere(in...) tp := "select %s from %s %s order by %s desc limit 1" sql := fmt.Sprintf(tp, fields, r.Table(), w, r.PrimaryKey()) err := db.Db.Get(&r, sql, args...) @@ -158,10 +199,10 @@ func (m model[T]) LastOne(where SqlBuilder, fields string) (T, error) { return r, nil } -func (m model[T]) FindMany(where SqlBuilder, fields string) ([]T, error) { +func (m model[T]) SimpleFind(where SqlBuilder, fields string, in ...[]interface{}) ([]T, error) { var r []T var rr T - w, args := where.parseWhere() + w, args := where.parseWhere(in...) tp := "select %s from %s %s" sql := fmt.Sprintf(tp, fields, rr.Table(), w) err := db.Db.Select(&r, sql, args...)