diff --git a/models/model.go b/models/model.go index 991da0e..1b50190 100644 --- a/models/model.go +++ b/models/model.go @@ -18,7 +18,7 @@ type Model interface { } type ParseWhere interface { - ParseWhere(in ...[]any) (string, []any) + ParseWhere(...[]any) (string, []any, error) } type SqlBuilder [][]string @@ -57,19 +57,26 @@ func (w SqlBuilder) parseIn(ss []string, s *strings.Builder, c *int, args *[]any return t } -func (w SqlBuilder) parseType(ss []string, args *[]any) { +func (w SqlBuilder) parseType(ss []string, args *[]any) error { if len(ss) == 4 && ss[3] == "int" { - i, _ := strconv.Atoi(ss[2]) + i, err := strconv.ParseInt(ss[2], 10, 64) + if err != nil { + return err + } *args = append(*args, i) } else if len(ss) == 4 && ss[3] == "float" { - i, _ := strconv.ParseFloat(ss[2], 64) + i, err := strconv.ParseFloat(ss[2], 64) + if err != nil { + return err + } *args = append(*args, i) } else { *args = append(*args, ss[2]) } + return nil } -func (w SqlBuilder) ParseWhere(in ...[]any) (string, []any) { +func (w SqlBuilder) ParseWhere(in ...[]any) (string, []any, error) { var s strings.Builder args := make([]any, 0, len(w)) c := 0 @@ -86,7 +93,10 @@ func (w SqlBuilder) ParseWhere(in ...[]any) (string, []any) { continue } s.WriteString(" ? and ") - w.parseType(ss, &args) + err := w.parseType(ss, &args) + if err != nil { + return "", nil, err + } } else if len(ss) >= 5 && len(ss)%5 == 0 { j := len(ss) / 5 for i := 0; i < j; i++ { @@ -109,7 +119,10 @@ func (w SqlBuilder) ParseWhere(in ...[]any) (string, []any) { continue } s.WriteString(" ? and ") - w.parseType(ss[start+1:end], &args) + err := w.parseType(ss[start+1:end], &args) + if err != nil { + return "", nil, err + } } st := s.String() st = strings.TrimRight(st, "and ") @@ -125,7 +138,7 @@ func (w SqlBuilder) ParseWhere(in ...[]any) (string, []any) { s.WriteString(ss) ss = s.String() } - return ss, args + return ss, args, nil } func (w SqlBuilder) parseOrderBy() string { @@ -167,7 +180,10 @@ func (w SqlBuilder) parseJoin() string { func SimplePagination[T Model](where ParseWhere, fields, group string, page, pageSize int, order SqlBuilder, join SqlBuilder, in ...[]any) (r []T, total int, err error) { var rr T - w, args := where.ParseWhere(in...) + w, args, err := where.ParseWhere(in...) + if err != nil { + return r, total, err + } n := struct { N int `db:"n" json:"n"` }{} @@ -225,10 +241,13 @@ func FindOneById[T Model, I ~int | ~uint64 | ~int64 | ~int32](id I) (T, error) { func FirstOne[T Model](where ParseWhere, fields string, order SqlBuilder, in ...[]any) (T, error) { var r T - w, args := where.ParseWhere(in...) + w, args, err := where.ParseWhere(in...) + if err != nil { + return r, err + } tp := "select %s from %s %s %s" sql := fmt.Sprintf(tp, fields, r.Table(), w, order.parseOrderBy()) - err := db.Db.Get(&r, sql, args...) + err = db.Db.Get(&r, sql, args...) if err != nil { return r, err } @@ -237,10 +256,13 @@ func FirstOne[T Model](where ParseWhere, fields string, order SqlBuilder, in ... func LastOne[T Model](where ParseWhere, fields string, in ...[]any) (T, error) { var r T - w, args := where.ParseWhere(in...) + w, args, err := where.ParseWhere(in...) + if err != nil { + return r, err + } tp := "select %s from %s %s order by %s desc limit 1" sql := fmt.Sprintf(tp, fields, r.Table(), w, r.PrimaryKey()) - err := db.Db.Get(&r, sql, args...) + err = db.Db.Get(&r, sql, args...) if err != nil { return r, err } @@ -250,10 +272,13 @@ func LastOne[T Model](where ParseWhere, fields string, in ...[]any) (T, error) { func SimpleFind[T Model](where ParseWhere, fields string, in ...[]any) ([]T, error) { var r []T var rr T - w, args := where.ParseWhere(in...) + w, args, err := where.ParseWhere(in...) + if err != nil { + return r, err + } tp := "select %s from %s %s" sql := fmt.Sprintf(tp, fields, rr.Table(), w) - err := db.Db.Select(&r, sql, args...) + err = db.Db.Select(&r, sql, args...) if err != nil { return r, err } @@ -276,7 +301,10 @@ func Find[T Model](where ParseWhere, fields, group string, order SqlBuilder, joi w := "" var args []any if where != nil { - w, args = where.ParseWhere(in...) + w, args, err = where.ParseWhere(in...) + if err != nil { + return r, err + } } j := join.parseJoin()