diff --git a/HISTORY.md b/HISTORY.md index a088ca84..00adaf95 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,9 @@ +## v7.1.0 + +* [FIXED] Embedded pointers with property names that duplicate parent struct properties. [#23](https://github.com/doug-martin/goqu/issues/23) +* [FIXED] Can't scan values using []byte or []string [#90](https://github.com/doug-martin/goqu/issues/90) + * When a slice that is `*sql.RawBytes`, `*[]byte` or `sql.Scanner` no errors will be returned. + ## v7.0.1 * Fix issue where structs with pointer fields where not set properly [#86](https://github.com/doug-martin/goqu/pull/86) and [#89](https://github.com/doug-martin/goqu/pull/89) - [@efureev](https://github.com/efureev) diff --git a/exec/query_executor.go b/exec/query_executor.go index 87a895f0..d759f19d 100644 --- a/exec/query_executor.go +++ b/exec/query_executor.go @@ -203,13 +203,19 @@ func (q QueryExecutor) ScanValContext(ctx context.Context, i interface{}) (bool, } val = reflect.Indirect(val) if util.IsSlice(val.Kind()) { - return false, errScanValNonSlice + switch i.(type) { + case *sql.RawBytes: // do nothing + case *[]byte: // do nothing + case sql.Scanner: // do nothing + default: + return false, errScanValNonSlice + } } rows, err := q.QueryContext(ctx) if err != nil { return false, err } - return NewScanner(rows).ScanVals(i) + return NewScanner(rows).ScanVal(i) } func (q QueryExecutor) rowsScanner(ctx context.Context) (Scanner, error) { diff --git a/exec/scanner.go b/exec/scanner.go index 18830d43..e15feab2 100644 --- a/exec/scanner.go +++ b/exec/scanner.go @@ -13,6 +13,7 @@ type ( Scanner interface { ScanStructs(i interface{}) (bool, error) ScanVals(i interface{}) (bool, error) + ScanVal(i interface{}) (found bool, err error) } scanner struct { rows *sql.Rows @@ -76,28 +77,35 @@ func (q *scanner) ScanVals(i interface{}) (found bool, err error) { defer q.rows.Close() val := reflect.Indirect(reflect.ValueOf(i)) t, _, isSliceOfPointers := util.GetTypeInfo(i, val) - switch val.Kind() { - case reflect.Slice: - for q.rows.Next() { - found = true - row := reflect.New(t) - if err = q.rows.Scan(row.Interface()); err != nil { - return found, err - } - if isSliceOfPointers { - val.Set(reflect.Append(val, row)) - } else { - val.Set(reflect.Append(val, reflect.Indirect(row))) - } + for q.rows.Next() { + found = true + row := reflect.New(t) + if err = q.rows.Scan(row.Interface()); err != nil { + return found, err } - default: - for q.rows.Next() { - found = true - if err = q.rows.Scan(i); err != nil { - return false, err - } + if isSliceOfPointers { + val.Set(reflect.Append(val, row)) + } else { + val.Set(reflect.Append(val, reflect.Indirect(row))) } + } + return found, q.rows.Err() +} +// This will execute the SQL and append results to the slice. +// var ids []uint32 +// if err := From("test").Select("id").ScanVals(&ids); err != nil{ +// panic(err.Error() +// } +// +// i: Takes a pointer to a slice of primitive values. +func (q *scanner) ScanVal(i interface{}) (found bool, err error) { + defer q.rows.Close() + for q.rows.Next() { + found = true + if err = q.rows.Scan(i); err != nil { + return false, err + } } return found, q.rows.Err() } diff --git a/exec/scanner_test.go b/exec/scanner_test.go index 29be04c2..d8c14077 100644 --- a/exec/scanner_test.go +++ b/exec/scanner_test.go @@ -3,6 +3,7 @@ package exec import ( "context" "database/sql" + "encoding/json" "fmt" "testing" @@ -38,6 +39,18 @@ type TestEmbeddedPtrCrudActionItem struct { Age int64 `db:"age"` } +type TestComposedDuplicateFieldsItem struct { + TestCrudActionItem + Address string `db:"other_address"` + Name string `db:"other_name"` +} + +type TestComposedPointerDuplicateFieldsItem struct { + *TestCrudActionItem + Address string `db:"other_address"` + Name string `db:"other_name"` +} + var ( testAddr1 = "111 Test Addr" testAddr2 = "211 Test Addr" @@ -219,6 +232,64 @@ func (cet *crudExecTest) TestScanStructs_pointersWithEmbeddedStruct() { }, composed) } +func (cet *crudExecTest) TestScanStructs_pointersWithEmbeddedStructDuplicateFields() { + t := cet.T() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + + mock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"address", "name", "other_address", "other_name"}). + FromCSVString("111 Test Addr,Test1,111 Test Addr Other,Test1 Other\n211 Test Addr,Test2,211 Test Addr Other,Test2 Other")) + + db := newMockDb(mDb) + e := newQueryExecutor(db, nil, `SELECT * FROM "items"`) + + var composed []*TestComposedDuplicateFieldsItem + assert.NoError(t, e.ScanStructs(&composed)) + assert.Equal(t, []*TestComposedDuplicateFieldsItem{ + { + TestCrudActionItem: TestCrudActionItem{Address: "111 Test Addr", Name: "Test1"}, + Address: "111 Test Addr Other", + Name: "Test1 Other", + }, + { + TestCrudActionItem: TestCrudActionItem{Address: "211 Test Addr", Name: "Test2"}, + Address: "211 Test Addr Other", + Name: "Test2 Other", + }, + }, composed) +} + +func (cet *crudExecTest) TestScanStructs_pointersWithEmbeddedPointerDuplicateFields() { + t := cet.T() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + + mock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"address", "name", "other_address", "other_name"}). + FromCSVString("111 Test Addr,Test1,111 Test Addr Other,Test1 Other\n211 Test Addr,Test2,211 Test Addr Other,Test2 Other")) + + db := newMockDb(mDb) + e := newQueryExecutor(db, nil, `SELECT * FROM "items"`) + + var composed []*TestComposedPointerDuplicateFieldsItem + assert.NoError(t, e.ScanStructs(&composed)) + assert.Equal(t, []*TestComposedPointerDuplicateFieldsItem{ + { + TestCrudActionItem: &TestCrudActionItem{Address: "111 Test Addr", Name: "Test1"}, + Address: "111 Test Addr Other", + Name: "Test1 Other", + }, + { + TestCrudActionItem: &TestCrudActionItem{Address: "211 Test Addr", Name: "Test2"}, + Address: "211 Test Addr Other", + Name: "Test2 Other", + }, + }, composed) +} + func (cet *crudExecTest) TestScanStructs_withEmbeddedStructPointer() { t := cet.T() mDb, mock, err := sqlmock.New() @@ -637,6 +708,81 @@ func (cet *crudExecTest) TestScanVal() { assert.Equal(t, ptrID, int64(1)) } +func (cet *crudExecTest) TestScanVal_withByteSlice() { + t := cet.T() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + + mock.ExpectQuery(`SELECT "name" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"name"}).FromCSVString("byte slice result")) + + db := newMockDb(mDb) + e := newQueryExecutor(db, nil, `SELECT "name" FROM "items"`) + + var bytes []byte + found, err := e.ScanVal(bytes) + assert.Equal(t, errScanValPointer, err) + assert.False(t, found) + + found, err = e.ScanVal(&bytes) + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, []byte("byte slice result"), bytes) +} + +func (cet *crudExecTest) TestScanVal_withRawBytes() { + t := cet.T() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + + mock.ExpectQuery(`SELECT "name" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"name"}).FromCSVString("byte slice result")) + + db := newMockDb(mDb) + e := newQueryExecutor(db, nil, `SELECT "name" FROM "items"`) + + var bytes sql.RawBytes + found, err := e.ScanVal(bytes) + assert.Equal(t, errScanValPointer, err) + assert.False(t, found) + + found, err = e.ScanVal(&bytes) + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, sql.RawBytes("byte slice result"), bytes) +} + +type JSONBoolArray []bool + +func (b *JSONBoolArray) Scan(src interface{}) error { + return json.Unmarshal(src.([]byte), b) +} + +func (cet *crudExecTest) TestScanVal_withValuerSlice() { + t := cet.T() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + + mock.ExpectQuery(`SELECT "bools" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"bools"}).FromCSVString(`"[true, false, true]"`)) + + db := newMockDb(mDb) + e := newQueryExecutor(db, nil, `SELECT "bools" FROM "items"`) + + var bools JSONBoolArray + found, err := e.ScanVal(bools) + assert.Equal(t, errScanValPointer, err) + assert.False(t, found) + + found, err = e.ScanVal(&bools) + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, JSONBoolArray{true, false, true}, bools) +} + func TestCrudExecSuite(t *testing.T) { suite.Run(t, new(crudExecTest)) } diff --git a/internal/util/reflect.go b/internal/util/reflect.go index d78efeac..40f0af66 100644 --- a/internal/util/reflect.go +++ b/internal/util/reflect.go @@ -11,7 +11,7 @@ import ( type ( ColumnData struct { ColumnName string - FieldName string + FieldIndex []int GoType reflect.Type } ColumnMap map[string]ColumnData @@ -121,7 +121,7 @@ func assignRowData(row reflect.Value, rd rowData, cm ColumnMap) { src, ok := rd[name] if ok { srcVal := reflect.ValueOf(src) - f := row.FieldByName(data.FieldName) + f := row.FieldByIndex(data.FieldIndex) if f.Kind() == reflect.Ptr { f.Set(srcVal) } else { @@ -153,21 +153,21 @@ func GetColumnMap(i interface{}) (ColumnMap, error) { structMapCacheLock.Lock() defer structMapCacheLock.Unlock() if _, ok := structMapCache[t]; !ok { - structMapCache[t] = createColumnMap(t) + structMapCache[t] = createColumnMap(t, []int{}) } return structMapCache[t], nil } -func createColumnMap(t reflect.Type) ColumnMap { +func createColumnMap(t reflect.Type, fieldIndex []int) ColumnMap { cm, n := ColumnMap{}, t.NumField() var subColMaps []ColumnMap for i := 0; i < n; i++ { f := t.Field(i) if f.Anonymous && (f.Type.Kind() == reflect.Struct || f.Type.Kind() == reflect.Ptr) { if f.Type.Kind() == reflect.Ptr { - subColMaps = append(subColMaps, createColumnMap(f.Type.Elem())) + subColMaps = append(subColMaps, createColumnMap(f.Type.Elem(), append(fieldIndex, f.Index...))) } else { - subColMaps = append(subColMaps, createColumnMap(f.Type)) + subColMaps = append(subColMaps, createColumnMap(f.Type, append(fieldIndex, f.Index...))) } } else { columnName := f.Tag.Get("db") @@ -177,7 +177,7 @@ func createColumnMap(t reflect.Type) ColumnMap { if columnName != "-" { cm[columnName] = ColumnData{ ColumnName: columnName, - FieldName: f.Name, + FieldIndex: append(fieldIndex, f.Index...), GoType: f.Type, } } diff --git a/internal/util/reflect_test.go b/internal/util/reflect_test.go index 49bc9d2b..1d60851d 100644 --- a/internal/util/reflect_test.go +++ b/internal/util/reflect_test.go @@ -722,10 +722,10 @@ func (rt *reflectTest) TestGetColumnMap_withStruct() { cm, err := GetColumnMap(&ts) assert.NoError(t, err) assert.Equal(t, ColumnMap{ - "str": {ColumnName: "str", FieldName: "Str", GoType: reflect.TypeOf("")}, - "int": {ColumnName: "int", FieldName: "Int", GoType: reflect.TypeOf(int64(1))}, - "bool": {ColumnName: "bool", FieldName: "Bool", GoType: reflect.TypeOf(true)}, - "valuer": {ColumnName: "valuer", FieldName: "Valuer", GoType: reflect.TypeOf(&sql.NullString{})}, + "str": {ColumnName: "str", FieldIndex: []int{0}, GoType: reflect.TypeOf("")}, + "int": {ColumnName: "int", FieldIndex: []int{1}, GoType: reflect.TypeOf(int64(1))}, + "bool": {ColumnName: "bool", FieldIndex: []int{2}, GoType: reflect.TypeOf(true)}, + "valuer": {ColumnName: "valuer", FieldIndex: []int{3}, GoType: reflect.TypeOf(&sql.NullString{})}, }, cm) } @@ -742,10 +742,10 @@ func (rt *reflectTest) TestGetColumnMap_withStructWithTag() { cm, err := GetColumnMap(&ts) assert.NoError(t, err) assert.Equal(t, ColumnMap{ - "s": {ColumnName: "s", FieldName: "Str", GoType: reflect.TypeOf("")}, - "i": {ColumnName: "i", FieldName: "Int", GoType: reflect.TypeOf(int64(1))}, - "b": {ColumnName: "b", FieldName: "Bool", GoType: reflect.TypeOf(true)}, - "v": {ColumnName: "v", FieldName: "Valuer", GoType: reflect.TypeOf(&sql.NullString{})}, + "s": {ColumnName: "s", FieldIndex: []int{0}, GoType: reflect.TypeOf("")}, + "i": {ColumnName: "i", FieldIndex: []int{1}, GoType: reflect.TypeOf(int64(1))}, + "b": {ColumnName: "b", FieldIndex: []int{2}, GoType: reflect.TypeOf(true)}, + "v": {ColumnName: "v", FieldIndex: []int{3}, GoType: reflect.TypeOf(&sql.NullString{})}, }, cm) } @@ -762,9 +762,9 @@ func (rt *reflectTest) TestGetColumnMap_withStructWithTransientFields() { cm, err := GetColumnMap(&ts) assert.NoError(t, err) assert.Equal(t, ColumnMap{ - "str": {ColumnName: "str", FieldName: "Str", GoType: reflect.TypeOf("")}, - "int": {ColumnName: "int", FieldName: "Int", GoType: reflect.TypeOf(int64(1))}, - "bool": {ColumnName: "bool", FieldName: "Bool", GoType: reflect.TypeOf(true)}, + "str": {ColumnName: "str", FieldIndex: []int{0}, GoType: reflect.TypeOf("")}, + "int": {ColumnName: "int", FieldIndex: []int{1}, GoType: reflect.TypeOf(int64(1))}, + "bool": {ColumnName: "bool", FieldIndex: []int{2}, GoType: reflect.TypeOf(true)}, }, cm) } @@ -781,10 +781,10 @@ func (rt *reflectTest) TestGetColumnMap_withSliceOfStructs() { cm, err := GetColumnMap(&ts) assert.NoError(t, err) assert.Equal(t, ColumnMap{ - "str": {ColumnName: "str", FieldName: "Str", GoType: reflect.TypeOf("")}, - "int": {ColumnName: "int", FieldName: "Int", GoType: reflect.TypeOf(int64(1))}, - "bool": {ColumnName: "bool", FieldName: "Bool", GoType: reflect.TypeOf(true)}, - "valuer": {ColumnName: "valuer", FieldName: "Valuer", GoType: reflect.TypeOf(&sql.NullString{})}, + "str": {ColumnName: "str", FieldIndex: []int{0}, GoType: reflect.TypeOf("")}, + "int": {ColumnName: "int", FieldIndex: []int{1}, GoType: reflect.TypeOf(int64(1))}, + "bool": {ColumnName: "bool", FieldIndex: []int{2}, GoType: reflect.TypeOf(true)}, + "valuer": {ColumnName: "valuer", FieldIndex: []int{3}, GoType: reflect.TypeOf(&sql.NullString{})}, }, cm) } @@ -813,10 +813,10 @@ func (rt *reflectTest) TestGetColumnMap_withStructWithEmbeddedStruct() { cm, err := GetColumnMap(&ts) assert.NoError(t, err) assert.Equal(t, ColumnMap{ - "str": {ColumnName: "str", FieldName: "Str", GoType: reflect.TypeOf("")}, - "int": {ColumnName: "int", FieldName: "Int", GoType: reflect.TypeOf(int64(1))}, - "bool": {ColumnName: "bool", FieldName: "Bool", GoType: reflect.TypeOf(true)}, - "valuer": {ColumnName: "valuer", FieldName: "Valuer", GoType: reflect.TypeOf(&sql.NullString{})}, + "str": {ColumnName: "str", FieldIndex: []int{0, 0}, GoType: reflect.TypeOf("")}, + "int": {ColumnName: "int", FieldIndex: []int{1}, GoType: reflect.TypeOf(int64(1))}, + "bool": {ColumnName: "bool", FieldIndex: []int{2}, GoType: reflect.TypeOf(true)}, + "valuer": {ColumnName: "valuer", FieldIndex: []int{3}, GoType: reflect.TypeOf(&sql.NullString{})}, }, cm) } @@ -836,10 +836,10 @@ func (rt *reflectTest) TestGetColumnMap_withStructWithEmbeddedStructPointer() { cm, err := GetColumnMap(&ts) assert.NoError(t, err) assert.Equal(t, ColumnMap{ - "str": {ColumnName: "str", FieldName: "Str", GoType: reflect.TypeOf("")}, - "int": {ColumnName: "int", FieldName: "Int", GoType: reflect.TypeOf(int64(1))}, - "bool": {ColumnName: "bool", FieldName: "Bool", GoType: reflect.TypeOf(true)}, - "valuer": {ColumnName: "valuer", FieldName: "Valuer", GoType: reflect.TypeOf(&sql.NullString{})}, + "str": {ColumnName: "str", FieldIndex: []int{0, 0}, GoType: reflect.TypeOf("")}, + "int": {ColumnName: "int", FieldIndex: []int{1}, GoType: reflect.TypeOf(int64(1))}, + "bool": {ColumnName: "bool", FieldIndex: []int{2}, GoType: reflect.TypeOf(true)}, + "valuer": {ColumnName: "valuer", FieldIndex: []int{3}, GoType: reflect.TypeOf(&sql.NullString{})}, }, cm) }