diff --git a/model/condition.go b/model/condition.go index da4ef85..7da10d8 100644 --- a/model/condition.go +++ b/model/condition.go @@ -2,6 +2,7 @@ package model type QueryCondition struct { where ParseWhere + from string fields string group string order SqlBuilder @@ -37,6 +38,12 @@ func Fields(fields string) Condition { } } +func From(from string) Condition { + return func(c *QueryCondition) { + c.from = from + } +} + func Group(group string) Condition { return func(c *QueryCondition) { c.group = group diff --git a/model/querycondition.go b/model/querycondition.go index 31c4514..8cb4a59 100644 --- a/model/querycondition.go +++ b/model/querycondition.go @@ -247,14 +247,17 @@ func BuildQuerySql[T Model](q *QueryCondition) (r string, args []any, err error) } tp := "select %s from %s %s %s %s %s %s %s" l := "" - + table := rr.Table() + if q.from != "" { + table = q.from + } if q.limit > 0 { l = fmt.Sprintf(" limit %d", q.limit) } if q.offset > 0 { l = fmt.Sprintf(" %s offset %d", l, q.offset) } - r = fmt.Sprintf(tp, q.fields, rr.Table(), j, w, groupBy, h, q.order.parseOrderBy(), l) + r = fmt.Sprintf(tp, q.fields, table, j, w, groupBy, h, q.order.parseOrderBy(), l) return } diff --git a/model/querycondition_test.go b/model/querycondition_test.go index d2ccbde..142118f 100644 --- a/model/querycondition_test.go +++ b/model/querycondition_test.go @@ -569,3 +569,67 @@ func Test_gets(t *testing.T) { }) } } + +func Test_finds(t *testing.T) { + type args struct { + db dbQuery + ctx context.Context + q *QueryCondition + } + type testCase[T Model] struct { + name string + args args + wantR []T + wantErr bool + } + var u user + tests := []testCase[options]{ + { + name: "sub query", + args: args{db: glob, ctx: ctx, q: Conditions( + From("(select * from wp_options where option_id <100) a"), + Where(SqlBuilder{{"option_id", ">", "50", "int"}}), + )}, + wantR: func() []options { + r, err := Select[options](ctx, "select * from (select * from wp_options where option_id <100) a where option_id>50") + if err != nil { + panic(err) + } + return r + }(), + wantErr: false, + }, + { + name: "mixed query", + args: args{db: glob, ctx: ctx, q: Conditions( + From("(select * from wp_options where option_id <100) a"), + Where(SqlBuilder{ + {"u.ID", "<", "50", "int"}}), + Join(SqlBuilder{ + {"left join", user.Table(u) + " u", "a.option_id=u.ID"}, + }), + Fields("u.user_login autoload,option_name,option_value"), + )}, + wantR: func() []options { + r, err := Select[options](ctx, "select u.user_login autoload,option_name,option_value from (select * from wp_options where option_id <100) a left join wp_users u on a.option_id=u.ID where `u`.`ID`<50") + if err != nil { + panic(err) + } + return r + }(), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotR, err := finds[options](tt.args.db, tt.args.ctx, tt.args.q) + if (err != nil) != tt.wantErr { + t.Errorf("finds() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotR, tt.wantR) { + t.Errorf("finds() gotR = %v, want %v", gotR, tt.wantR) + } + }) + } +}