From 3151d9addea0745947664d1d1a3c54d051b84313 Mon Sep 17 00:00:00 2001 From: doug-martin Date: Wed, 10 Jul 2019 19:41:21 -0500 Subject: [PATCH] Fix for #90 * When a slice that is `*sql.RawBytes`, `*[]byte` or `sql.Scanner` no errors will be returned. --- HISTORY.md | 2 ++ exec/query_executor.go | 10 ++++-- exec/scanner.go | 46 ++++++++++++++----------- exec/scanner_test.go | 76 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 21 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index e791c9e6..00adaf95 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,8 @@ ## v7.1.0 * [FIXED] Embedded pointers with property names that duplicate parent struct properties. [#23](https://github.com/doug-martin/goqu/issues/23) +* [FIXED] Can't scan values using []byte or []string [#90](https://github.com/doug-martin/goqu/issues/90) + * When a slice that is `*sql.RawBytes`, `*[]byte` or `sql.Scanner` no errors will be returned. ## v7.0.1 diff --git a/exec/query_executor.go b/exec/query_executor.go index 87a895f0..d759f19d 100644 --- a/exec/query_executor.go +++ b/exec/query_executor.go @@ -203,13 +203,19 @@ func (q QueryExecutor) ScanValContext(ctx context.Context, i interface{}) (bool, } val = reflect.Indirect(val) if util.IsSlice(val.Kind()) { - return false, errScanValNonSlice + switch i.(type) { + case *sql.RawBytes: // do nothing + case *[]byte: // do nothing + case sql.Scanner: // do nothing + default: + return false, errScanValNonSlice + } } rows, err := q.QueryContext(ctx) if err != nil { return false, err } - return NewScanner(rows).ScanVals(i) + return NewScanner(rows).ScanVal(i) } func (q QueryExecutor) rowsScanner(ctx context.Context) (Scanner, error) { diff --git a/exec/scanner.go b/exec/scanner.go index 18830d43..e15feab2 100644 --- a/exec/scanner.go +++ b/exec/scanner.go @@ -13,6 +13,7 @@ type ( Scanner interface { ScanStructs(i interface{}) (bool, error) ScanVals(i interface{}) (bool, error) + ScanVal(i interface{}) (found bool, err error) } scanner struct { rows *sql.Rows @@ -76,28 +77,35 @@ func (q *scanner) ScanVals(i interface{}) (found bool, err error) { defer q.rows.Close() val := reflect.Indirect(reflect.ValueOf(i)) t, _, isSliceOfPointers := util.GetTypeInfo(i, val) - switch val.Kind() { - case reflect.Slice: - for q.rows.Next() { - found = true - row := reflect.New(t) - if err = q.rows.Scan(row.Interface()); err != nil { - return found, err - } - if isSliceOfPointers { - val.Set(reflect.Append(val, row)) - } else { - val.Set(reflect.Append(val, reflect.Indirect(row))) - } + for q.rows.Next() { + found = true + row := reflect.New(t) + if err = q.rows.Scan(row.Interface()); err != nil { + return found, err } - default: - for q.rows.Next() { - found = true - if err = q.rows.Scan(i); err != nil { - return false, err - } + if isSliceOfPointers { + val.Set(reflect.Append(val, row)) + } else { + val.Set(reflect.Append(val, reflect.Indirect(row))) } + } + return found, q.rows.Err() +} +// This will execute the SQL and append results to the slice. +// var ids []uint32 +// if err := From("test").Select("id").ScanVals(&ids); err != nil{ +// panic(err.Error() +// } +// +// i: Takes a pointer to a slice of primitive values. +func (q *scanner) ScanVal(i interface{}) (found bool, err error) { + defer q.rows.Close() + for q.rows.Next() { + found = true + if err = q.rows.Scan(i); err != nil { + return false, err + } } return found, q.rows.Err() } diff --git a/exec/scanner_test.go b/exec/scanner_test.go index b4423469..d8c14077 100644 --- a/exec/scanner_test.go +++ b/exec/scanner_test.go @@ -3,6 +3,7 @@ package exec import ( "context" "database/sql" + "encoding/json" "fmt" "testing" @@ -707,6 +708,81 @@ func (cet *crudExecTest) TestScanVal() { assert.Equal(t, ptrID, int64(1)) } +func (cet *crudExecTest) TestScanVal_withByteSlice() { + t := cet.T() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + + mock.ExpectQuery(`SELECT "name" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"name"}).FromCSVString("byte slice result")) + + db := newMockDb(mDb) + e := newQueryExecutor(db, nil, `SELECT "name" FROM "items"`) + + var bytes []byte + found, err := e.ScanVal(bytes) + assert.Equal(t, errScanValPointer, err) + assert.False(t, found) + + found, err = e.ScanVal(&bytes) + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, []byte("byte slice result"), bytes) +} + +func (cet *crudExecTest) TestScanVal_withRawBytes() { + t := cet.T() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + + mock.ExpectQuery(`SELECT "name" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"name"}).FromCSVString("byte slice result")) + + db := newMockDb(mDb) + e := newQueryExecutor(db, nil, `SELECT "name" FROM "items"`) + + var bytes sql.RawBytes + found, err := e.ScanVal(bytes) + assert.Equal(t, errScanValPointer, err) + assert.False(t, found) + + found, err = e.ScanVal(&bytes) + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, sql.RawBytes("byte slice result"), bytes) +} + +type JSONBoolArray []bool + +func (b *JSONBoolArray) Scan(src interface{}) error { + return json.Unmarshal(src.([]byte), b) +} + +func (cet *crudExecTest) TestScanVal_withValuerSlice() { + t := cet.T() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + + mock.ExpectQuery(`SELECT "bools" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"bools"}).FromCSVString(`"[true, false, true]"`)) + + db := newMockDb(mDb) + e := newQueryExecutor(db, nil, `SELECT "bools" FROM "items"`) + + var bools JSONBoolArray + found, err := e.ScanVal(bools) + assert.Equal(t, errScanValPointer, err) + assert.False(t, found) + + found, err = e.ScanVal(&bools) + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, JSONBoolArray{true, false, true}, bools) +} + func TestCrudExecSuite(t *testing.T) { suite.Run(t, new(crudExecTest)) }