diff --git a/pgserver/iter.go b/pgserver/iter.go index 4ec69d6..e83b160 100644 --- a/pgserver/iter.go +++ b/pgserver/iter.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "fmt" "io" + "math/big" "strings" "github.com/dolthub/go-mysql-server/sql" @@ -87,6 +88,7 @@ type SqlRowIter struct { decimals []int lists []int + hugeInts []int } func NewSqlRowIter(rows *stdsql.Rows, schema sql.Schema) (*SqlRowIter, error) { @@ -116,7 +118,14 @@ func NewSqlRowIter(rows *stdsql.Rows, schema sql.Schema) (*SqlRowIter, error) { } } - iter := &SqlRowIter{rows, columns, schema, buf, ptrs, decimals, lists} + var hugeInts []int + for i, t := range columns { + if t.DatabaseTypeName() == "HUGEINT" { + hugeInts = append(hugeInts, i) + } + } + + iter := &SqlRowIter{rows, columns, schema, buf, ptrs, decimals, lists, hugeInts} if logrus.GetLevel() >= logrus.DebugLevel { logrus.Debugf("New " + iter.String() + "\n") } @@ -197,6 +206,21 @@ func (iter *SqlRowIter) Next(ctx *sql.Context) (sql.Row, error) { iter.buffer[idx] = pgtype.FlatArray[any](list) } + for _, idx := range iter.hugeInts { + switch v := iter.buffer[idx].(type) { + case nil: + continue + case *big.Int: + var n pgtype.Numeric + if err := n.Scan(v.String()); err != nil { + return nil, err + } + iter.buffer[idx] = n + default: + return nil, fmt.Errorf("unexpected type %T for big.Int value", v) + } + } + // Prune or fill the values to match the schema width := len(iter.schema) // the desired width if width == 0 {