Skip to content

Commit

Permalink
Add Column Metadata #152
Browse files Browse the repository at this point in the history
  • Loading branch information
alespour committed Jun 22, 2020
1 parent 3b99664 commit 3b533ba
Show file tree
Hide file tree
Showing 10 changed files with 422 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ It only asserts that argument is of `time.Time` type.

## Change Log

- **2019-04-06** - added functionality to mock a sql MetaData request
- **2019-02-13** - added `go.mod` removed the references and suggestions using `gopkg.in`.
- **2018-12-11** - added expectation of Rows to be closed, while mocking expected query.
- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching.
Expand Down
77 changes: 77 additions & 0 deletions column.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package sqlmock

import "reflect"

// Column is a mocked column Metadata for rows.ColumnTypes()
type Column struct {
name string
dbType string
nullable bool
nullableOk bool
length int64
lengthOk bool
precision int64
scale int64
psOk bool
scanType reflect.Type
}

func (c *Column) Name() string {
return c.name
}

func (c *Column) DbType() string {
return c.dbType
}

func (c *Column) IsNullable() (bool, bool) {
return c.nullable, c.nullableOk
}

func (c *Column) Length() (int64, bool) {
return c.length, c.lengthOk
}

func (c *Column) PrecisionScale() (int64, int64, bool) {
return c.precision, c.scale, c.psOk
}

func (c *Column) ScanType() reflect.Type {
return c.scanType
}

// NewColumn returns a Column with specified name
func NewColumn(name string) *Column {
return &Column{
name: name,
}
}

// Nullable returns the column with nullable metadata set
func (c *Column) Nullable(nullable bool) *Column {
c.nullable = nullable
c.nullableOk = true
return c
}

// OfType returns the column with type metadata set
func (c *Column) OfType(dbType string, sampleValue interface{}) *Column {
c.dbType = dbType
c.scanType = reflect.TypeOf(sampleValue)
return c
}

// WithLength returns the column with length metadata set.
func (c *Column) WithLength(length int64) *Column {
c.length = length
c.lengthOk = true
return c
}

// WithPrecisionAndScale returns the column with precision and scale metadata set.
func (c *Column) WithPrecisionAndScale(precision, scale int64) *Column {
c.precision = precision
c.scale = scale
c.psOk = true
return c
}
63 changes: 63 additions & 0 deletions column_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package sqlmock

import (
"reflect"
"testing"
"time"
)

func TestColumn(t *testing.T) {
now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z")
column1 := NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
column2 := NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
column3 := NewColumn("when").OfType("TIMESTAMP", now)

if column1.ScanType().Kind() != reflect.String {
t.Errorf("string scanType mismatch: %v", column1.ScanType())
}
if column2.ScanType().Kind() != reflect.Float64 {
t.Errorf("float scanType mismatch: %v", column2.ScanType())
}
if column3.ScanType() != reflect.TypeOf(time.Time{}) {
t.Errorf("time scanType mismatch: %v", column3.ScanType())
}

nullable, ok := column1.IsNullable()
if !nullable || !ok {
t.Errorf("'test' column should be nullable")
}
nullable, ok = column2.IsNullable()
if nullable || !ok {
t.Errorf("'number' column should not be nullable")
}
nullable, ok = column3.IsNullable()
if ok {
t.Errorf("'when' column nullability should be unknown")
}

length, ok := column1.Length()
if length != 100 || !ok {
t.Errorf("'test' column wrong length")
}
length, ok = column2.Length()
if ok {
t.Errorf("'number' column is not of variable length type")
}
length, ok = column3.Length()
if ok {
t.Errorf("'when' column is not of variable length type")
}

_, _, ok = column1.PrecisionScale()
if ok {
t.Errorf("'test' column not applicable")
}
precision, scale, ok := column2.PrecisionScale()
if precision != 10 || scale != 4 || !ok {
t.Errorf("'number' column not applicable")
}
_, _, ok = column3.PrecisionScale()
if ok {
t.Errorf("'when' column not applicable")
}
}
10 changes: 9 additions & 1 deletion expectations_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,19 @@ import (
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
defs := 0
sets := make([]*Rows, len(rows))
for i, r := range rows {
sets[i] = r
if r.def != nil {
defs++
}
}
if defs > 0 && defs == len(sets) {
e.rows = &rowSetsWithDefinition{&rowSets{sets: sets, ex: e}}
} else {
e.rows = &rowSets{sets: sets, ex: e}
}
e.rows = &rowSets{sets: sets, ex: e}
return e
}

Expand Down
1 change: 1 addition & 0 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func (rs *rowSets) invalidateRaw() {
type Rows struct {
converter driver.ValueConverter
cols []string
def []*Column
rows [][]driver.Value
pos int
nextErr map[int]error
Expand Down
56 changes: 55 additions & 1 deletion rows_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

package sqlmock

import "io"
import (
"database/sql/driver"
"io"
"reflect"
)

// Implement the "RowsNextResultSet" interface
func (rs *rowSets) HasNextResultSet() bool {
Expand All @@ -18,3 +22,53 @@ func (rs *rowSets) NextResultSet() error {
rs.pos++
return nil
}

// type for rows with columns definition created with sqlmock.NewRowsWithColumnDefinition
type rowSetsWithDefinition struct {
*rowSets
}

// Implement the "RowsColumnTypeDatabaseTypeName" interface
func (rs *rowSetsWithDefinition) ColumnTypeDatabaseTypeName(index int) string {
return rs.getDefinition(index).DbType()
}

// Implement the "RowsColumnTypeLength" interface
func (rs *rowSetsWithDefinition) ColumnTypeLength(index int) (length int64, ok bool) {
return rs.getDefinition(index).Length()
}

// Implement the "RowsColumnTypeNullable" interface
func (rs *rowSetsWithDefinition) ColumnTypeNullable(index int) (nullable, ok bool) {
return rs.getDefinition(index).IsNullable()
}

// Implement the "RowsColumnTypePrecisionScale" interface
func (rs *rowSetsWithDefinition) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
return rs.getDefinition(index).PrecisionScale()
}

// ColumnTypeScanType is defined from driver.RowsColumnTypeScanType
func (rs *rowSetsWithDefinition) ColumnTypeScanType(index int) reflect.Type {
return rs.getDefinition(index).ScanType()
}

// return column definition from current set metadata
func (rs *rowSetsWithDefinition) getDefinition(index int) *Column {
return rs.sets[rs.pos].def[index]
}

// NewRowsWithColumnDefinition return rows with columns metadata
func NewRowsWithColumnDefinition(columns ...*Column) *Rows {
cols := make([]string, len(columns))
for i, column := range columns {
cols[i] = column.Name()
}

return &Rows{
cols: cols,
def: columns,
nextErr: make(map[int]error),
converter: driver.DefaultParameterConverter,
}
}
Loading

0 comments on commit 3b533ba

Please sign in to comment.