Skip to content

Commit 23fc745

Browse files
committed
add Match for mockdb
1 parent 24a6c71 commit 23fc745

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

sqlmock/sqlmock.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"database/sql/driver"
88
"encoding/json"
99
"fmt"
10+
"regexp"
1011
"sync"
1112
)
1213

@@ -23,6 +24,7 @@ var (
2324
// ExpectedQuery 表示一个预期的查询
2425
type ExpectedQuery struct {
2526
query string
27+
matcher *regexp.Regexp
2628
args []driver.Value
2729
rows [][]driver.Value
2830
columns []string
@@ -55,6 +57,19 @@ func (m *MockDB) ExpectQuery(query string, args ...driver.Value) *ExpectedQuery
5557
return eq
5658
}
5759

60+
// ExpectQuery 期望一个特定的查询
61+
func (m *MockDB) Macth(matcher string, args ...driver.Value) *ExpectedQuery {
62+
m.mu.Lock()
63+
defer m.mu.Unlock()
64+
65+
eq := &ExpectedQuery{
66+
args: args,
67+
matcher: regexp.MustCompile(matcher),
68+
}
69+
m.expected = append(m.expected, eq)
70+
return eq
71+
}
72+
5873
func (eq *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery {
5974
eq.args = args
6075
return eq
@@ -141,6 +156,16 @@ func (ms *MockStmt) Exec(args []driver.Value) (driver.Result, error) {
141156
defer ms.mockDB.mu.Unlock()
142157

143158
for i, expected := range ms.mockDB.expected {
159+
if expected.matcher != nil && expected.matcher.MatchString(ms.query) && matchArgs(ms.mockDB.expected[i].args, args) {
160+
ms.mockDB.expected = append(ms.mockDB.expected[:i], ms.mockDB.expected[i+1:]...)
161+
162+
if expected.err != nil {
163+
return nil, expected.err
164+
}
165+
166+
return &MockResult{}, nil
167+
}
168+
144169
if CompareSQL(expected.query, ms.query) && matchArgs(expected.args, args) {
145170
ms.mockDB.expected = append(ms.mockDB.expected[:i], ms.mockDB.expected[i+1:]...)
146171

@@ -160,6 +185,16 @@ func (ms *MockStmt) Query(args []driver.Value) (driver.Rows, error) {
160185
defer ms.mockDB.mu.Unlock()
161186

162187
for i, expected := range ms.mockDB.expected {
188+
if expected.matcher != nil && expected.matcher.MatchString(ms.query) && matchArgs(ms.mockDB.expected[i].args, args) {
189+
ms.mockDB.expected = append(ms.mockDB.expected[:i], ms.mockDB.expected[i+1:]...)
190+
191+
if expected.err != nil {
192+
return nil, expected.err
193+
}
194+
195+
return &MockRows{columns: expected.columns, rows: expected.rows}, nil
196+
}
197+
163198
if CompareSQL(expected.query, ms.query) && matchArgs(expected.args, args) {
164199
ms.mockDB.expected = append(ms.mockDB.expected[:i], ms.mockDB.expected[i+1:]...)
165200

sqlmock/sqlmock_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,64 @@ func TestMultipleQueries(t *testing.T) {
213213
t.Errorf("unexpected user: %+v", user2)
214214
}
215215
}
216+
func TestMockDB_Match(t *testing.T) {
217+
// 1. 创建 mock 数据库
218+
mockDB := sqlmock.NewMock()
219+
220+
// 2. 期望一个匹配查询并设置返回值
221+
pattern := `^SELECT\s+[a-zA-Z0-9_,\s]+FROM\s+[a-zA-Z0-9_]+\s+WHERE\s+[a-zA-Z0-9_]+\s*=\s*\?$`
222+
223+
mockDB.Macth(pattern).
224+
WithArgs("Alice").
225+
WillReturnRows([]string{"id", "name", "age"}, [][]driver.Value{
226+
{2, "Alice", 25},
227+
})
228+
229+
// 3. 打开数据库连接
230+
db, err := mockDB.Open("mock")
231+
if err != nil {
232+
t.Fatalf("failed to open database: %v", err)
233+
}
234+
235+
// 4. 创建仓库实例
236+
repo := &UserRepository{db: db}
237+
238+
// 5. 执行测试
239+
var user User
240+
err = repo.db.QueryRow("SELECT id, name, age FROM users WHERE name = ?", "Alice").Scan(&user.ID, &user.Name, &user.Age)
241+
if err != nil {
242+
t.Fatalf("unexpected error: %v", err)
243+
}
244+
245+
// 6. 验证结果
246+
if user.ID != 2 || user.Name != "Alice" || user.Age != 25 {
247+
t.Errorf("unexpected user: %+v", user)
248+
}
249+
}
250+
251+
func TestMockDB_Match_NoMatch(t *testing.T) {
252+
// 1. 创建 mock 数据库
253+
mockDB := sqlmock.NewMock()
254+
255+
// 2. 期望一个匹配查询并设置返回值
256+
mockDB.Macth("SELECT id, name, age FROM users WHERE name = ?", "Alice").
257+
WillReturnRows([]string{"id", "name", "age"}, [][]driver.Value{
258+
{2, "Alice", 25},
259+
})
260+
261+
// 3. 打开数据库连接
262+
db, err := mockDB.Open("mock")
263+
if err != nil {
264+
t.Fatalf("failed to open database: %v", err)
265+
}
266+
267+
// 4. 创建仓库实例
268+
repo := &UserRepository{db: db}
269+
270+
// 5. 执行测试
271+
var user User
272+
err = repo.db.QueryRow("SELECT id, name, age FROM users WHERE name = ?", "Bob").Scan(&user.ID, &user.Name, &user.Age)
273+
if err == nil {
274+
t.Fatalf("expected error, got nil")
275+
}
276+
}

0 commit comments

Comments
 (0)