Skip to content

Commit 830b6ef

Browse files
author
Dmitriy Seredenko
committed
* Support $1, $n queries like in PostgresSQL
* Added tests for it * Small formatting update
1 parent 02def5b commit 830b6ef

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

conn.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"database/sql/driver"
66
"errors"
7+
"log"
8+
"regexp"
79
"strings"
810
"sync"
911
)
@@ -66,9 +68,19 @@ func (c *FakeConn) Prepare(query string) (driver.Stmt, error) {
6668
// context is for the preparation of the statement,
6769
// it must not store the context within the statement itself.
6870
func (c *FakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
69-
var firstStmt = &FakeStmt{q: query, connection: c} // Create statement
70-
firstStmt.placeholders = len(strings.Split(query, "?")) - 1 // Checking how many placeholders do we have
71-
queryParts := strings.Split(query, " ") // By First statement define the query type
71+
var firstStmt = &FakeStmt{q: query, connection: c}
72+
// Checking how many placeholders do we have
73+
if strings.Contains(query, "$1") {
74+
r, err := regexp.Compile(`[$]\d+`)
75+
if err != nil {
76+
log.Fatalf(`Cant't compile regexp with err [%v]`, err)
77+
}
78+
firstStmt.placeholders = len(strings.Split(r.ReplaceAllString(query, `$$$`), "$$")) - 1 // Postgres notation
79+
} else {
80+
firstStmt.placeholders = len(strings.Split(query, "?")) - 1 // Postgres notation
81+
}
82+
83+
queryParts := strings.Split(query, " ") // By First statement define the query type
7284
firstStmt.command = strings.ToUpper(queryParts[0])
7385
return firstStmt, nil
7486
}

response.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
package go_mocket
22

33
import (
4+
"database/sql"
45
"database/sql/driver"
56
"fmt"
67
"log"
78
"reflect"
89
"strings"
9-
"database/sql"
1010
)
11+
1112
const (
1213
DRIVER_NAME = "MOCK_FAKE_DRIVER"
1314
)

response_test.go

+23-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func CreateUsersWithError(db *sql.DB) error {
4343
return err
4444
}
4545

46-
func InsertRecord(db *sql.DB) int64 {
46+
func InsertRecord(db *sql.DB) int64 {
4747
res, err := db.Exec(`INSERT INTO foo VALUES("bar", ?))`, "value")
4848
if err != nil {
4949
return 0
@@ -143,4 +143,26 @@ func TestResponses(t *testing.T) {
143143
t.Fatalf("Last insert id not returned. Expected: [%v] , Got: [%v]", mockedId, returnedId)
144144
}
145145
})
146+
147+
t.Run(`Recognise both ? and $1 Postgres placeholders for raw query`, func(t *testing.T) {
148+
t.Run("Question mark", func(t *testing.T) {
149+
testFunc := func(db *sql.DB) string {
150+
var name string
151+
err := db.QueryRow(`SELECT * FROM foo a = $1 AND b = $2 AND c = $3`, "value", "value2", "value3").Scan(&name)
152+
if err != nil {
153+
t.Fatalf("Test function failed [%v]", err)
154+
return ""
155+
}
156+
return name
157+
}
158+
159+
Catcher.Reset().NewMock().WithQuery("SELECT * FROM foo ").WithReply([]map[string]interface{}{{"name": "full_name"}})
160+
returnedName := testFunc(DB)
161+
162+
if returnedName != "full_name" {
163+
t.Fatalf("Returned name mismatches. Expected: [%v] , Got: [%v]", "full_name", returnedName)
164+
}
165+
166+
})
167+
})
146168
}

0 commit comments

Comments
 (0)