Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue 209: add output parameters support #335

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions examples/output/output.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package main

import (
"database/sql"
"fmt"
)

// this type emulate go-mssqldb's ReturnStatus type, used to get rc from SQL Server stored procedures
// https://github.com/microsoft/go-mssqldb/blob/main/mssql.go
type ReturnStatus int32

func execWithNamedOutputArgs(db *sql.DB, outputArg *string, inputOutputArg *string) (err error) {
_, err = db.Exec("EXEC spWithNamedOutputParameters",
sql.Named("outArg", sql.Out{Dest: outputArg}),
sql.Named("inoutArg", sql.Out{In: true, Dest: inputOutputArg}),
)
if err != nil {
return
}
return
}

func execWithTypedOutputArgs(db *sql.DB, rcArg *ReturnStatus) (err error) {
if _, err = db.Exec("EXEC spWithReturnCode", rcArg); err != nil {
return
}
return
}

func main() {
// @NOTE: the real connection is not required for tests
db, err := sql.Open("mssql", "myconnectionstring")
if err != nil {
panic(err)
}
defer db.Close()

outputArg := ""
inputOutputArg := "abcInput"

if err = execWithNamedOutputArgs(db, &outputArg, &inputOutputArg); err != nil {
panic(err)
}

rcArg := new(ReturnStatus)
if err = execWithTypedOutputArgs(db, rcArg); err != nil {
panic(err)
}

if _, err = fmt.Printf("outputArg: %s, inputOutputArg: %s, rcArg: %d", outputArg, inputOutputArg, *rcArg); err != nil {
panic(err)
}
}
75 changes: 75 additions & 0 deletions examples/output/output_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package main

import (
"testing"

"github.com/DATA-DOG/go-sqlmock"
)

func TestNamedOutputArgs(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()

inOutInputValue := "abcInput"
mock.ExpectExec("EXEC spWithNamedOutputParameters").
WithArgs(
sqlmock.NamedOutputArg("outArg", "123Output"),
sqlmock.NamedInputOutputArg("inoutArg", &inOutInputValue, "abcOutput"),
).
WillReturnResult(sqlmock.NewResult(1, 1))

// now we execute our method
outArg := ""
inoutArg := "abcInput"
if err = execWithNamedOutputArgs(db, &outArg, &inoutArg); err != nil {
t.Errorf("error was not expected while updating stats: %s", err)
}

// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}

if outArg != "123Output" {
t.Errorf("unexpected outArg value")
}

if inoutArg != "abcOutput" {
t.Errorf("unexpected inoutArg value")
}
}

func TestTypedOutputArgs(t *testing.T) {
rcArg := new(ReturnStatus) // here we will store the return code

valueConverter := sqlmock.NewPassthroughValueConverter(rcArg) // we need this converter to bypass the default ValueConverter logic that alter original value's type
db, mock, err := sqlmock.New(sqlmock.ValueConverterOption(valueConverter))
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()

rcFromSp := ReturnStatus(123) // simulate the return code from the stored procedure
mock.ExpectExec("EXEC spWithReturnCode").
WithArgs(
sqlmock.TypedOutputArg(&rcFromSp), // using this func we can provide the expected type and value
).
WillReturnResult(sqlmock.NewResult(1, 1))

// now we execute our method
if err = execWithTypedOutputArgs(db, rcArg); err != nil {
t.Errorf("error was not expected while updating stats: %s", err)
}

// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}

if *rcArg != 123 {
t.Errorf("unexpected rcArg value")
}
}
88 changes: 88 additions & 0 deletions output_args.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package sqlmock

import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
)

type namedInOutValue struct {
Name string
ExpectedInValue interface{}
ReturnedOutValue interface{}
In bool
}

// Match implements the Argument interface, allowing check if the given value matches the expected input value provided using NamedInputOutputArg function.
func (n namedInOutValue) Match(v driver.Value) bool {
out, ok := v.(sql.Out)

return ok && out.In == n.In && (!n.In || reflect.DeepEqual(out.Dest, n.ExpectedInValue))
}

// NamedInputArg can ben used to simulate an output value passed back from the database.
// returnedOutValue can be a value or a pointer to the value.
func NamedOutputArg(name string, returnedOutValue interface{}) interface{} {
return namedInOutValue{
Name: name,
ReturnedOutValue: returnedOutValue,
In: false,
}
}

// NamedInputOutputArg can be used to both check if expected input value is provided and to simulate an output value passed back from the database.
// expectedInValue must be a pointer to the value, returnedOutValue can be a value or a pointer to the value.
func NamedInputOutputArg(name string, expectedInValue interface{}, returnedOutValue interface{}) interface{} {
return namedInOutValue{
Name: name,
ExpectedInValue: expectedInValue,
ReturnedOutValue: returnedOutValue,
In: true,
}
}

type typedOutValue struct {
TypeName string
ReturnedOutValue interface{}
}

// Match implements the Argument interface, allowing check if the given value matches the expected type provided using TypedOutputArg function.
func (n typedOutValue) Match(v driver.Value) bool {
return n.TypeName == fmt.Sprintf("%T", v)
}

// TypeOutputArg can be used to simulate an output value passed back from the database, setting value based on the type.
// returnedOutValue must be a pointer to the value.
func TypedOutputArg(returnedOutValue interface{}) interface{} {
return typedOutValue{
TypeName: fmt.Sprintf("%T", returnedOutValue),
ReturnedOutValue: returnedOutValue,
}
}

func setOutputValues(currentArgs []driver.NamedValue, expectedArgs []driver.Value) {
for _, expectedArg := range expectedArgs {
if outVal, ok := expectedArg.(namedInOutValue); ok {
for _, currentArg := range currentArgs {
if currentArg.Name == outVal.Name {
if sqlOut, ok := currentArg.Value.(sql.Out); ok {
reflect.ValueOf(sqlOut.Dest).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue)))
}

break
}
}
}

if outVal, ok := expectedArg.(typedOutValue); ok {
for _, currentArg := range currentArgs {
if fmt.Sprintf("%T", currentArg.Value) == outVal.TypeName {
reflect.ValueOf(currentArg.Value).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue)))

break
}
}
}
}
}
31 changes: 31 additions & 0 deletions passthroughvalueconverter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package sqlmock

import (
"database/sql/driver"
"fmt"
)

type PassthroughValueConverter struct {
passthroughTypes []string
}

func NewPassthroughValueConverter(typeSamples ...interface{}) *PassthroughValueConverter {
c := &PassthroughValueConverter{}

for _, sampleValue := range typeSamples {
c.passthroughTypes = append(c.passthroughTypes, fmt.Sprintf("%T", sampleValue))
}

return c
}

func (c *PassthroughValueConverter) ConvertValue(v interface{}) (driver.Value, error) {
valueType := fmt.Sprintf("%T", v)
for _, passthroughType := range c.passthroughTypes {
if valueType == passthroughType {
return v, nil
}
}

return driver.DefaultParameterConverter.ConvertValue(v)
}
5 changes: 5 additions & 0 deletions sqlmock_go18.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build go1.8
// +build go1.8

package sqlmock
Expand Down Expand Up @@ -250,6 +251,8 @@ func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery,
return expected, expected.err // mocked to return error
}

setOutputValues(args, expected.args)

if expected.rows == nil {
return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
}
Expand Down Expand Up @@ -332,6 +335,8 @@ func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, e
return expected, expected.err // mocked to return error
}

setOutputValues(args, expected.args)

if expected.result == nil {
return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected)
}
Expand Down
Loading