diff --git a/examples/output/output.go b/examples/output/output.go new file mode 100644 index 0000000..fd2c046 --- /dev/null +++ b/examples/output/output.go @@ -0,0 +1,53 @@ +package main + +import ( + "database/sql" + "fmt" +) + +// this type emulate go-mssqldb's ReturnStatus type, used to get rc from SQL Server stored procedures +// https://github.com/microsoft/go-mssqldb/blob/main/mssql.go +type ReturnStatus int32 + +func execWithNamedOutputArgs(db *sql.DB, outputArg *string, inputOutputArg *string) (err error) { + _, err = db.Exec("EXEC spWithNamedOutputParameters", + sql.Named("outArg", sql.Out{Dest: outputArg}), + sql.Named("inoutArg", sql.Out{In: true, Dest: inputOutputArg}), + ) + if err != nil { + return + } + return +} + +func execWithTypedOutputArgs(db *sql.DB, rcArg *ReturnStatus) (err error) { + if _, err = db.Exec("EXEC spWithReturnCode", rcArg); err != nil { + return + } + return +} + +func main() { + // @NOTE: the real connection is not required for tests + db, err := sql.Open("mssql", "myconnectionstring") + if err != nil { + panic(err) + } + defer db.Close() + + outputArg := "" + inputOutputArg := "abcInput" + + if err = execWithNamedOutputArgs(db, &outputArg, &inputOutputArg); err != nil { + panic(err) + } + + rcArg := new(ReturnStatus) + if err = execWithTypedOutputArgs(db, rcArg); err != nil { + panic(err) + } + + if _, err = fmt.Printf("outputArg: %s, inputOutputArg: %s, rcArg: %d", outputArg, inputOutputArg, *rcArg); err != nil { + panic(err) + } +} diff --git a/examples/output/output_test.go b/examples/output/output_test.go new file mode 100644 index 0000000..548849a --- /dev/null +++ b/examples/output/output_test.go @@ -0,0 +1,75 @@ +package main + +import ( + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func TestNamedOutputArgs(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + inOutInputValue := "abcInput" + mock.ExpectExec("EXEC spWithNamedOutputParameters"). + WithArgs( + sqlmock.NamedOutputArg("outArg", "123Output"), + sqlmock.NamedInputOutputArg("inoutArg", &inOutInputValue, "abcOutput"), + ). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // now we execute our method + outArg := "" + inoutArg := "abcInput" + if err = execWithNamedOutputArgs(db, &outArg, &inoutArg); err != nil { + t.Errorf("error was not expected while updating stats: %s", err) + } + + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + + if outArg != "123Output" { + t.Errorf("unexpected outArg value") + } + + if inoutArg != "abcOutput" { + t.Errorf("unexpected inoutArg value") + } +} + +func TestTypedOutputArgs(t *testing.T) { + rcArg := new(ReturnStatus) // here we will store the return code + + valueConverter := sqlmock.NewPassthroughValueConverter(rcArg) // we need this converter to bypass the default ValueConverter logic that alter original value's type + db, mock, err := sqlmock.New(sqlmock.ValueConverterOption(valueConverter)) + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rcFromSp := ReturnStatus(123) // simulate the return code from the stored procedure + mock.ExpectExec("EXEC spWithReturnCode"). + WithArgs( + sqlmock.TypedOutputArg(&rcFromSp), // using this func we can provide the expected type and value + ). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // now we execute our method + if err = execWithTypedOutputArgs(db, rcArg); err != nil { + t.Errorf("error was not expected while updating stats: %s", err) + } + + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + + if *rcArg != 123 { + t.Errorf("unexpected rcArg value") + } +} diff --git a/output_args.go b/output_args.go new file mode 100644 index 0000000..87040dd --- /dev/null +++ b/output_args.go @@ -0,0 +1,88 @@ +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" +) + +type namedInOutValue struct { + Name string + ExpectedInValue interface{} + ReturnedOutValue interface{} + In bool +} + +// Match implements the Argument interface, allowing check if the given value matches the expected input value provided using NamedInputOutputArg function. +func (n namedInOutValue) Match(v driver.Value) bool { + out, ok := v.(sql.Out) + + return ok && out.In == n.In && (!n.In || reflect.DeepEqual(out.Dest, n.ExpectedInValue)) +} + +// NamedInputArg can ben used to simulate an output value passed back from the database. +// returnedOutValue can be a value or a pointer to the value. +func NamedOutputArg(name string, returnedOutValue interface{}) interface{} { + return namedInOutValue{ + Name: name, + ReturnedOutValue: returnedOutValue, + In: false, + } +} + +// NamedInputOutputArg can be used to both check if expected input value is provided and to simulate an output value passed back from the database. +// expectedInValue must be a pointer to the value, returnedOutValue can be a value or a pointer to the value. +func NamedInputOutputArg(name string, expectedInValue interface{}, returnedOutValue interface{}) interface{} { + return namedInOutValue{ + Name: name, + ExpectedInValue: expectedInValue, + ReturnedOutValue: returnedOutValue, + In: true, + } +} + +type typedOutValue struct { + TypeName string + ReturnedOutValue interface{} +} + +// Match implements the Argument interface, allowing check if the given value matches the expected type provided using TypedOutputArg function. +func (n typedOutValue) Match(v driver.Value) bool { + return n.TypeName == fmt.Sprintf("%T", v) +} + +// TypeOutputArg can be used to simulate an output value passed back from the database, setting value based on the type. +// returnedOutValue must be a pointer to the value. +func TypedOutputArg(returnedOutValue interface{}) interface{} { + return typedOutValue{ + TypeName: fmt.Sprintf("%T", returnedOutValue), + ReturnedOutValue: returnedOutValue, + } +} + +func setOutputValues(currentArgs []driver.NamedValue, expectedArgs []driver.Value) { + for _, expectedArg := range expectedArgs { + if outVal, ok := expectedArg.(namedInOutValue); ok { + for _, currentArg := range currentArgs { + if currentArg.Name == outVal.Name { + if sqlOut, ok := currentArg.Value.(sql.Out); ok { + reflect.ValueOf(sqlOut.Dest).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue))) + } + + break + } + } + } + + if outVal, ok := expectedArg.(typedOutValue); ok { + for _, currentArg := range currentArgs { + if fmt.Sprintf("%T", currentArg.Value) == outVal.TypeName { + reflect.ValueOf(currentArg.Value).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue))) + + break + } + } + } + } +} diff --git a/passthroughvalueconverter.go b/passthroughvalueconverter.go new file mode 100644 index 0000000..b7c0939 --- /dev/null +++ b/passthroughvalueconverter.go @@ -0,0 +1,31 @@ +package sqlmock + +import ( + "database/sql/driver" + "fmt" +) + +type PassthroughValueConverter struct { + passthroughTypes []string +} + +func NewPassthroughValueConverter(typeSamples ...interface{}) *PassthroughValueConverter { + c := &PassthroughValueConverter{} + + for _, sampleValue := range typeSamples { + c.passthroughTypes = append(c.passthroughTypes, fmt.Sprintf("%T", sampleValue)) + } + + return c +} + +func (c *PassthroughValueConverter) ConvertValue(v interface{}) (driver.Value, error) { + valueType := fmt.Sprintf("%T", v) + for _, passthroughType := range c.passthroughTypes { + if valueType == passthroughType { + return v, nil + } + } + + return driver.DefaultParameterConverter.ConvertValue(v) +} diff --git a/sqlmock_go18.go b/sqlmock_go18.go index f268900..28abe0d 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -1,3 +1,4 @@ +//go:build go1.8 // +build go1.8 package sqlmock @@ -250,6 +251,8 @@ func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery, return expected, expected.err // mocked to return error } + setOutputValues(args, expected.args) + if expected.rows == nil { return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) } @@ -332,6 +335,8 @@ func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, e return expected, expected.err // mocked to return error } + setOutputValues(args, expected.args) + if expected.result == nil { return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected) } diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go index 6267f38..248271a 100644 --- a/sqlmock_go18_test.go +++ b/sqlmock_go18_test.go @@ -132,6 +132,172 @@ func TestContextExecWithNamedArg(t *testing.T) { } } +func TestContextExecWithNamedOutputArg(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + expectedOut1 := 10 + expectedOut2 := 20 + mock.ExpectExec("EXEC ProcWithOutParam"). + WithArgs( + sql.Named("id", 5), + NamedOutputArg("out1", expectedOut1), // out1 is int + NamedOutputArg("out2", &expectedOut2), // out2 is *int + ). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + actualOut1 := new(int) + actualOut2 := new(int) + _, err = db.ExecContext( + ctx, + "EXEC ProcWithOutParam", + sql.Named("id", 5), + sql.Named("out1", sql.Out{Dest: actualOut1, In: false}), + sql.Named("out2", sql.Out{Dest: actualOut2, In: false}), + ) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + + if expectedOut1 != *actualOut1 { + t.Errorf("output param value was not expected, was expecting %v, but got %v", expectedOut1, actualOut1) + } + + if expectedOut2 != *actualOut2 { + t.Errorf("output param value was not expected, was expecting %v, but got %v", expectedOut2, actualOut2) + } + +} + +func TestContextExecWithNamedInputOutputArg(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + expectedIn1 := 11 + expectedOut1 := 12 + expectedIn2 := 21 + expectedOut2 := 22 + mock.ExpectExec("EXEC ProcWithInOutParam"). + WithArgs( + sql.Named("id", 5), + NamedInputOutputArg("inout1", &expectedIn1, expectedOut1), // out1 is int + NamedInputOutputArg("inout2", &expectedIn2, &expectedOut2), // out2 is *int + ). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + var actualInOut1 *int = &expectedIn1 + var actualInOut2 *int = &expectedIn2 + _, err = db.ExecContext( + ctx, + "EXEC ProcWithInOutParam", + sql.Named("id", 5), + sql.Named("inout1", sql.Out{Dest: actualInOut1, In: true}), + sql.Named("inout2", sql.Out{Dest: actualInOut2, In: true}), + ) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + + if expectedOut1 != *actualInOut1 { + t.Errorf("output param value was not expected, was expecting %v, but got %v", expectedOut1, actualInOut1) + } + + if expectedOut2 != *actualInOut2 { + t.Errorf("output param value was not expected, was expecting %v, but got %v", expectedOut2, actualInOut2) + } +} + +type MockReturnStatus int32 +type MockOtherOutputType int32 + +func TestContextExecWithTypedOutputArg(t *testing.T) { + t.Parallel() + + rs1 := new(MockReturnStatus) + rs2 := new(MockOtherOutputType) + converter := NewPassthroughValueConverter(rs1, rs2) + db, mock, err := New(ValueConverterOption(converter)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + expectedOut1 := MockReturnStatus(10) + expectedOut2 := MockOtherOutputType(20) + mock.ExpectExec("EXEC ProcWithTypedRCParam"). + WithArgs( + sql.Named("id", 5), + TypedOutputArg(&expectedOut1), + TypedOutputArg(&expectedOut2), + ). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + var actualOut1 MockReturnStatus + var actualOut2 MockOtherOutputType + _, err = db.ExecContext( + ctx, + "EXEC ProcWithTypedRCParam", + sql.Named("id", 5), + &actualOut1, + &actualOut2, + ) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + + if expectedOut1 != actualOut1 { + t.Errorf("output param value was not expected, was expecting %v, but got %v", expectedOut1, actualOut1) + } + + if expectedOut2 != actualOut2 { + t.Errorf("output param value was not expected, was expecting %v, but got %v", expectedOut2, actualOut2) + } +} + func TestContextExec(t *testing.T) { t.Parallel() db, mock, err := New()