Skip to content

Commit

Permalink
add output parameters support
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardos77 committed Jul 31, 2024
1 parent 6bed17c commit 6164ad6
Show file tree
Hide file tree
Showing 6 changed files with 418 additions and 0 deletions.
53 changes: 53 additions & 0 deletions examples/output/output.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package main

import (
"database/sql"
"fmt"
)

// this type emulate go-mssqldb's ReturnStatus type, used to get rc from SQL Server stored procedures
// https://github.com/microsoft/go-mssqldb/blob/main/mssql.go
type ReturnStatus int32

func execWithNamedOutputArgs(db *sql.DB, outputArg *string, inputOutputArg *string) (err error) {
_, err = db.Exec("EXEC spWithNamedOutputParameters",
sql.Named("outArg", sql.Out{Dest: outputArg}),
sql.Named("inoutArg", sql.Out{In: true, Dest: inputOutputArg}),
)
if err != nil {
return
}
return
}

func execWithTypedOutputArgs(db *sql.DB, rcArg *ReturnStatus) (err error) {
if _, err = db.Exec("EXEC spWithReturnCode", rcArg); err != nil {
return
}
return
}

func main() {
// @NOTE: the real connection is not required for tests
db, err := sql.Open("mssql", "myconnectionstring")
if err != nil {
panic(err)
}
defer db.Close()

outputArg := ""
inputOutputArg := "abcInput"

if err = execWithNamedOutputArgs(db, &outputArg, &inputOutputArg); err != nil {
panic(err)
}

rcArg := new(ReturnStatus)
if err = execWithTypedOutputArgs(db, rcArg); err != nil {
panic(err)
}

if _, err = fmt.Printf("outputArg: %s, inputOutputArg: %s, rcArg: %d", outputArg, inputOutputArg, *rcArg); err != nil {
panic(err)
}
}
75 changes: 75 additions & 0 deletions examples/output/output_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package main

import (
"testing"

"github.com/DATA-DOG/go-sqlmock"
)

func TestNamedOutputArgs(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()

inOutInputValue := "abcInput"
mock.ExpectExec("EXEC spWithNamedOutputParameters").
WithArgs(
sqlmock.NamedOutputArg("outArg", "123Output"),
sqlmock.NamedInputOutputArg("inoutArg", &inOutInputValue, "abcOutput"),
).
WillReturnResult(sqlmock.NewResult(1, 1))

// now we execute our method
outArg := ""
inoutArg := "abcInput"
if err = execWithNamedOutputArgs(db, &outArg, &inoutArg); err != nil {
t.Errorf("error was not expected while updating stats: %s", err)
}

// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}

if outArg != "123Output" {
t.Errorf("unexpected outArg value")
}

if inoutArg != "abcOutput" {
t.Errorf("unexpected inoutArg value")
}
}

func TestTypedOutputArgs(t *testing.T) {
rcArg := new(ReturnStatus) // here we will store the return code

valueConverter := sqlmock.NewPassthroughValueConverter(rcArg) // we need this converter to bypass the default ValueConverter logic that alter original value's type
db, mock, err := sqlmock.New(sqlmock.ValueConverterOption(valueConverter))
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()

rcFromSp := ReturnStatus(123) // simulate the return code from the stored procedure
mock.ExpectExec("EXEC spWithReturnCode").
WithArgs(
sqlmock.TypedOutputArg(&rcFromSp), // using this func we can provide the expected type and value
).
WillReturnResult(sqlmock.NewResult(1, 1))

// now we execute our method
if err = execWithTypedOutputArgs(db, rcArg); err != nil {
t.Errorf("error was not expected while updating stats: %s", err)
}

// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}

if *rcArg != 123 {
t.Errorf("unexpected rcArg value")
}
}
88 changes: 88 additions & 0 deletions output_args.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package sqlmock

import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
)

type namedInOutValue struct {
Name string
ExpectedInValue interface{}
ReturnedOutValue interface{}
In bool
}

// Match implements the Argument interface, allowing check if the given value matches the expected input value provided using NamedInputOutputArg function.
func (n namedInOutValue) Match(v driver.Value) bool {
out, ok := v.(sql.Out)

return ok && out.In == n.In && (!n.In || reflect.DeepEqual(out.Dest, n.ExpectedInValue))
}

// NamedInputArg can ben used to simulate an output value passed back from the database.
// returnedOutValue can be a value or a pointer to the value.
func NamedOutputArg(name string, returnedOutValue interface{}) interface{} {
return namedInOutValue{
Name: name,
ReturnedOutValue: returnedOutValue,
In: false,
}
}

// NamedInputOutputArg can be used to both check if expected input value is provided and to simulate an output value passed back from the database.
// expectedInValue must be a pointer to the value, returnedOutValue can be a value or a pointer to the value.
func NamedInputOutputArg(name string, expectedInValue interface{}, returnedOutValue interface{}) interface{} {
return namedInOutValue{
Name: name,
ExpectedInValue: expectedInValue,
ReturnedOutValue: returnedOutValue,
In: true,
}
}

type typedOutValue struct {
TypeName string
ReturnedOutValue interface{}
}

// Match implements the Argument interface, allowing check if the given value matches the expected type provided using TypedOutputArg function.
func (n typedOutValue) Match(v driver.Value) bool {
return n.TypeName == fmt.Sprintf("%T", v)
}

// TypeOutputArg can be used to simulate an output value passed back from the database, setting value based on the type.
// returnedOutValue must be a pointer to the value.
func TypedOutputArg(returnedOutValue interface{}) interface{} {
return typedOutValue{
TypeName: fmt.Sprintf("%T", returnedOutValue),
ReturnedOutValue: returnedOutValue,
}
}

func setOutputValues(currentArgs []driver.NamedValue, expectedArgs []driver.Value) {
for _, expectedArg := range expectedArgs {
if outVal, ok := expectedArg.(namedInOutValue); ok {
for _, currentArg := range currentArgs {
if currentArg.Name == outVal.Name {
if sqlOut, ok := currentArg.Value.(sql.Out); ok {
reflect.ValueOf(sqlOut.Dest).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue)))
}

break
}
}
}

if outVal, ok := expectedArg.(typedOutValue); ok {
for _, currentArg := range currentArgs {
if fmt.Sprintf("%T", currentArg.Value) == outVal.TypeName {
reflect.ValueOf(currentArg.Value).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue)))

break
}
}
}
}
}
31 changes: 31 additions & 0 deletions passthroughvalueconverter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package sqlmock

import (
"database/sql/driver"
"fmt"
)

type PassthroughValueConverter struct {
passthroughTypes []string
}

func NewPassthroughValueConverter(typeSamples ...interface{}) *PassthroughValueConverter {
c := &PassthroughValueConverter{}

for _, sampleValue := range typeSamples {
c.passthroughTypes = append(c.passthroughTypes, fmt.Sprintf("%T", sampleValue))
}

return c
}

func (c *PassthroughValueConverter) ConvertValue(v interface{}) (driver.Value, error) {
valueType := fmt.Sprintf("%T", v)
for _, passthroughType := range c.passthroughTypes {
if valueType == passthroughType {
return v, nil
}
}

return driver.DefaultParameterConverter.ConvertValue(v)
}
5 changes: 5 additions & 0 deletions sqlmock_go18.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build go1.8
// +build go1.8

package sqlmock
Expand Down Expand Up @@ -250,6 +251,8 @@ func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery,
return expected, expected.err // mocked to return error
}

setOutputValues(args, expected.args)

if expected.rows == nil {
return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
}
Expand Down Expand Up @@ -332,6 +335,8 @@ func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, e
return expected, expected.err // mocked to return error
}

setOutputValues(args, expected.args)

if expected.result == nil {
return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected)
}
Expand Down
Loading

0 comments on commit 6164ad6

Please sign in to comment.