-
-
Notifications
You must be signed in to change notification settings - Fork 409
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6bed17c
commit 6164ad6
Showing
6 changed files
with
418 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package main | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
) | ||
|
||
// this type emulate go-mssqldb's ReturnStatus type, used to get rc from SQL Server stored procedures | ||
// https://github.com/microsoft/go-mssqldb/blob/main/mssql.go | ||
type ReturnStatus int32 | ||
|
||
func execWithNamedOutputArgs(db *sql.DB, outputArg *string, inputOutputArg *string) (err error) { | ||
_, err = db.Exec("EXEC spWithNamedOutputParameters", | ||
sql.Named("outArg", sql.Out{Dest: outputArg}), | ||
sql.Named("inoutArg", sql.Out{In: true, Dest: inputOutputArg}), | ||
) | ||
if err != nil { | ||
return | ||
} | ||
return | ||
} | ||
|
||
func execWithTypedOutputArgs(db *sql.DB, rcArg *ReturnStatus) (err error) { | ||
if _, err = db.Exec("EXEC spWithReturnCode", rcArg); err != nil { | ||
return | ||
} | ||
return | ||
} | ||
|
||
func main() { | ||
// @NOTE: the real connection is not required for tests | ||
db, err := sql.Open("mssql", "myconnectionstring") | ||
if err != nil { | ||
panic(err) | ||
} | ||
defer db.Close() | ||
|
||
outputArg := "" | ||
inputOutputArg := "abcInput" | ||
|
||
if err = execWithNamedOutputArgs(db, &outputArg, &inputOutputArg); err != nil { | ||
panic(err) | ||
} | ||
|
||
rcArg := new(ReturnStatus) | ||
if err = execWithTypedOutputArgs(db, rcArg); err != nil { | ||
panic(err) | ||
} | ||
|
||
if _, err = fmt.Printf("outputArg: %s, inputOutputArg: %s, rcArg: %d", outputArg, inputOutputArg, *rcArg); err != nil { | ||
panic(err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package main | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/DATA-DOG/go-sqlmock" | ||
) | ||
|
||
func TestNamedOutputArgs(t *testing.T) { | ||
db, mock, err := sqlmock.New() | ||
if err != nil { | ||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) | ||
} | ||
defer db.Close() | ||
|
||
inOutInputValue := "abcInput" | ||
mock.ExpectExec("EXEC spWithNamedOutputParameters"). | ||
WithArgs( | ||
sqlmock.NamedOutputArg("outArg", "123Output"), | ||
sqlmock.NamedInputOutputArg("inoutArg", &inOutInputValue, "abcOutput"), | ||
). | ||
WillReturnResult(sqlmock.NewResult(1, 1)) | ||
|
||
// now we execute our method | ||
outArg := "" | ||
inoutArg := "abcInput" | ||
if err = execWithNamedOutputArgs(db, &outArg, &inoutArg); err != nil { | ||
t.Errorf("error was not expected while updating stats: %s", err) | ||
} | ||
|
||
// we make sure that all expectations were met | ||
if err := mock.ExpectationsWereMet(); err != nil { | ||
t.Errorf("there were unfulfilled expectations: %s", err) | ||
} | ||
|
||
if outArg != "123Output" { | ||
t.Errorf("unexpected outArg value") | ||
} | ||
|
||
if inoutArg != "abcOutput" { | ||
t.Errorf("unexpected inoutArg value") | ||
} | ||
} | ||
|
||
func TestTypedOutputArgs(t *testing.T) { | ||
rcArg := new(ReturnStatus) // here we will store the return code | ||
|
||
valueConverter := sqlmock.NewPassthroughValueConverter(rcArg) // we need this converter to bypass the default ValueConverter logic that alter original value's type | ||
db, mock, err := sqlmock.New(sqlmock.ValueConverterOption(valueConverter)) | ||
if err != nil { | ||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) | ||
} | ||
defer db.Close() | ||
|
||
rcFromSp := ReturnStatus(123) // simulate the return code from the stored procedure | ||
mock.ExpectExec("EXEC spWithReturnCode"). | ||
WithArgs( | ||
sqlmock.TypedOutputArg(&rcFromSp), // using this func we can provide the expected type and value | ||
). | ||
WillReturnResult(sqlmock.NewResult(1, 1)) | ||
|
||
// now we execute our method | ||
if err = execWithTypedOutputArgs(db, rcArg); err != nil { | ||
t.Errorf("error was not expected while updating stats: %s", err) | ||
} | ||
|
||
// we make sure that all expectations were met | ||
if err := mock.ExpectationsWereMet(); err != nil { | ||
t.Errorf("there were unfulfilled expectations: %s", err) | ||
} | ||
|
||
if *rcArg != 123 { | ||
t.Errorf("unexpected rcArg value") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package sqlmock | ||
|
||
import ( | ||
"database/sql" | ||
"database/sql/driver" | ||
"fmt" | ||
"reflect" | ||
) | ||
|
||
type namedInOutValue struct { | ||
Name string | ||
ExpectedInValue interface{} | ||
ReturnedOutValue interface{} | ||
In bool | ||
} | ||
|
||
// Match implements the Argument interface, allowing check if the given value matches the expected input value provided using NamedInputOutputArg function. | ||
func (n namedInOutValue) Match(v driver.Value) bool { | ||
out, ok := v.(sql.Out) | ||
|
||
return ok && out.In == n.In && (!n.In || reflect.DeepEqual(out.Dest, n.ExpectedInValue)) | ||
} | ||
|
||
// NamedInputArg can ben used to simulate an output value passed back from the database. | ||
// returnedOutValue can be a value or a pointer to the value. | ||
func NamedOutputArg(name string, returnedOutValue interface{}) interface{} { | ||
return namedInOutValue{ | ||
Name: name, | ||
ReturnedOutValue: returnedOutValue, | ||
In: false, | ||
} | ||
} | ||
|
||
// NamedInputOutputArg can be used to both check if expected input value is provided and to simulate an output value passed back from the database. | ||
// expectedInValue must be a pointer to the value, returnedOutValue can be a value or a pointer to the value. | ||
func NamedInputOutputArg(name string, expectedInValue interface{}, returnedOutValue interface{}) interface{} { | ||
return namedInOutValue{ | ||
Name: name, | ||
ExpectedInValue: expectedInValue, | ||
ReturnedOutValue: returnedOutValue, | ||
In: true, | ||
} | ||
} | ||
|
||
type typedOutValue struct { | ||
TypeName string | ||
ReturnedOutValue interface{} | ||
} | ||
|
||
// Match implements the Argument interface, allowing check if the given value matches the expected type provided using TypedOutputArg function. | ||
func (n typedOutValue) Match(v driver.Value) bool { | ||
return n.TypeName == fmt.Sprintf("%T", v) | ||
} | ||
|
||
// TypeOutputArg can be used to simulate an output value passed back from the database, setting value based on the type. | ||
// returnedOutValue must be a pointer to the value. | ||
func TypedOutputArg(returnedOutValue interface{}) interface{} { | ||
return typedOutValue{ | ||
TypeName: fmt.Sprintf("%T", returnedOutValue), | ||
ReturnedOutValue: returnedOutValue, | ||
} | ||
} | ||
|
||
func setOutputValues(currentArgs []driver.NamedValue, expectedArgs []driver.Value) { | ||
for _, expectedArg := range expectedArgs { | ||
if outVal, ok := expectedArg.(namedInOutValue); ok { | ||
for _, currentArg := range currentArgs { | ||
if currentArg.Name == outVal.Name { | ||
if sqlOut, ok := currentArg.Value.(sql.Out); ok { | ||
reflect.ValueOf(sqlOut.Dest).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue))) | ||
} | ||
|
||
break | ||
} | ||
} | ||
} | ||
|
||
if outVal, ok := expectedArg.(typedOutValue); ok { | ||
for _, currentArg := range currentArgs { | ||
if fmt.Sprintf("%T", currentArg.Value) == outVal.TypeName { | ||
reflect.ValueOf(currentArg.Value).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue))) | ||
|
||
break | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package sqlmock | ||
|
||
import ( | ||
"database/sql/driver" | ||
"fmt" | ||
) | ||
|
||
type PassthroughValueConverter struct { | ||
passthroughTypes []string | ||
} | ||
|
||
func NewPassthroughValueConverter(typeSamples ...interface{}) *PassthroughValueConverter { | ||
c := &PassthroughValueConverter{} | ||
|
||
for _, sampleValue := range typeSamples { | ||
c.passthroughTypes = append(c.passthroughTypes, fmt.Sprintf("%T", sampleValue)) | ||
} | ||
|
||
return c | ||
} | ||
|
||
func (c *PassthroughValueConverter) ConvertValue(v interface{}) (driver.Value, error) { | ||
valueType := fmt.Sprintf("%T", v) | ||
for _, passthroughType := range c.passthroughTypes { | ||
if valueType == passthroughType { | ||
return v, nil | ||
} | ||
} | ||
|
||
return driver.DefaultParameterConverter.ConvertValue(v) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.