Skip to content

Commit

Permalink
Replace *big.Rat with shopspring decimal.Decimal (#1491)
Browse files Browse the repository at this point in the history
Rational numbers are awkward decimals,
using a decimal type for decimal data has a few benefits:
1. perf-wise decimal representation is a *big.Int & an int32 exponent
    (same as pgtype.Numeric), instead of 2 *big.Ints
2. decimal has obvious conversion to/from decimal strings,
    whereas *big.Rat can be wonky going through FloatString
4. with scripting most people will want to use their decimal values as decimals,
    not decimal rationals

A side effect of this is that we now convert NaN/inf/-inf numerics to null instead of NaN.
This would only be usable for pg<->pg, only causing trouble with other peers,
& even then we'd want inf/-inf to be properly translated

Fixes #1175
  • Loading branch information
serprex authored Mar 16, 2024
1 parent 78a8cdb commit 1110158
Show file tree
Hide file tree
Showing 17 changed files with 178 additions and 249 deletions.
4 changes: 2 additions & 2 deletions flow/connectors/postgres/qrep_query_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"bytes"
"context"
"fmt"
"math/big"
"testing"
"time"

"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/shopspring/decimal"

"github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
)
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestAllDataTypes(t *testing.T) {
}

expectedNumeric := "123.456"
actualNumeric := record[10].Value.(*big.Rat).FloatString(3)
actualNumeric := record[10].Value.(decimal.Decimal).String()
if actualNumeric != expectedNumeric {
t.Fatalf("expected %v, got %v", expectedNumeric, actualNumeric)
}
Expand Down
46 changes: 14 additions & 32 deletions flow/connectors/postgres/qvalue_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@ import (
"encoding/json"
"errors"
"fmt"
"math/big"
"net/netip"
"strings"
"time"

"github.com/jackc/pgx/v5/pgtype"
"github.com/lib/pq/oid"
"github.com/shopspring/decimal"

"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/PeerDB-io/peer-flow/shared"
)

var big10 = big.NewInt(10)

func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind {
switch recvOID {
case pgtype.BoolOID:
Expand Down Expand Up @@ -217,8 +215,7 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
val := qvalue.QValue{}

if value == nil {
val = qvalue.QValue{Kind: qvalueKind, Value: nil}
return val, nil
return qvalue.QValue{Kind: qvalueKind, Value: nil}, nil
}

switch qvalueKind {
Expand Down Expand Up @@ -341,11 +338,11 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
case qvalue.QValueKindNumeric:
numVal := value.(pgtype.Numeric)
if numVal.Valid {
rat, err := numericToRat(&numVal)
num, err := numericToDecimal(numVal)
if err != nil {
return qvalue.QValue{}, fmt.Errorf("failed to convert numeric [%v] to rat: %w", value, err)
return qvalue.QValue{}, fmt.Errorf("failed to convert numeric [%v] to decimal: %w", value, err)
}
val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: rat}
val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: num}
}
case qvalue.QValueKindArrayFloat32:
return convertToArray[float32](qvalueKind, value)
Expand Down Expand Up @@ -388,31 +385,16 @@ func (c *PostgresConnector) parseFieldFromPostgresOID(oid uint32, value interfac
return parseFieldFromQValueKind(c.postgresOIDToQValueKind(oid), value)
}

func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) {
if numVal.Valid {
if numVal.NaN {
// set to nil if NaN
return nil, nil
}

switch numVal.InfinityModifier {
case pgtype.NegativeInfinity, pgtype.Infinity:
return nil, nil
}

rat := new(big.Rat).SetInt(numVal.Int)
if numVal.Exp > 0 {
mul := new(big.Int).Exp(big10, big.NewInt(int64(numVal.Exp)), nil)
rat.Mul(rat, new(big.Rat).SetInt(mul))
} else if numVal.Exp < 0 {
mul := new(big.Int).Exp(big10, big.NewInt(int64(-numVal.Exp)), nil)
rat.Quo(rat, new(big.Rat).SetInt(mul))
}
return rat, nil
func numericToDecimal(numVal pgtype.Numeric) (interface{}, error) {
switch {
case !numVal.Valid:
return nil, errors.New("invalid numeric")
case numVal.NaN, numVal.InfinityModifier == pgtype.Infinity,
numVal.InfinityModifier == pgtype.NegativeInfinity:
return nil, nil
default:
return decimal.NewFromBigInt(numVal.Int, numVal.Exp), nil
}

// handle invalid numeric
return nil, errors.New("invalid numeric")
}

func customTypeToQKind(typeName string) qvalue.QValueKind {
Expand Down
5 changes: 2 additions & 3 deletions flow/connectors/snowflake/avro_file_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package connsnowflake
import (
"context"
"fmt"
"math/big"
"os"
"testing"
"time"

"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"

avro "github.com/PeerDB-io/peer-flow/connectors/utils/avro"
Expand Down Expand Up @@ -36,8 +36,7 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.
qvalue.QValueKindTimeTZ, qvalue.QValueKindDate:
value = time.Now()
case qvalue.QValueKindNumeric:
// create a new big.Rat for numeric data
value = big.NewRat(int64(placeHolder), 1)
value = decimal.New(int64(placeHolder), 1)
case qvalue.QValueKindUUID:
value = uuid.New() // assuming you have the github.com/google/uuid package
case qvalue.QValueKindQChar:
Expand Down
12 changes: 4 additions & 8 deletions flow/connectors/sql/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import (
"encoding/json"
"fmt"
"log/slog"
"math/big"
"strings"

"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jmoiron/sqlx"
"github.com/shopspring/decimal"
"go.temporal.io/sdk/activity"
"go.temporal.io/sdk/log"

Expand Down Expand Up @@ -212,7 +212,7 @@ func (g *GenericSQLQueryExecutor) processRows(ctx context.Context, rows *sqlx.Ro
case qvalue.QValueKindBytes, qvalue.QValueKindBit:
values[i] = new([]byte)
case qvalue.QValueKindNumeric:
var s sql.NullString
var s sql.Null[decimal.Decimal]
values[i] = &s
case qvalue.QValueKindUUID:
values[i] = new([]byte)
Expand Down Expand Up @@ -380,13 +380,9 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) {
}
}
case qvalue.QValueKindNumeric:
if v, ok := val.(*sql.NullString); ok {
if v, ok := val.(*sql.Null[decimal.Decimal]); ok {
if v.Valid {
numeric := new(big.Rat)
if _, ok := numeric.SetString(v.String); !ok {
return qvalue.QValue{}, fmt.Errorf("failed to parse numeric: %v", v.String)
}
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: numeric}, nil
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v.V}, nil
} else {
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: nil}, nil
}
Expand Down
4 changes: 2 additions & 2 deletions flow/connectors/utils/cdc_records/cdc_records_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"errors"
"fmt"
"log/slog"
"math/big"
"os"
"runtime"
"sync/atomic"
"time"

"github.com/cockroachdb/pebble"
"github.com/shopspring/decimal"
"go.temporal.io/sdk/log"

"github.com/PeerDB-io/peer-flow/model"
Expand Down Expand Up @@ -72,7 +72,7 @@ func (c *cdcRecordsStore) initPebbleDB() error {
gob.Register(&model.UpdateRecord{})
gob.Register(&model.DeleteRecord{})
gob.Register(time.Time{})
gob.Register(&big.Rat{})
gob.Register(decimal.Decimal{})

var err error
// we don't want a WAL since cache, we don't want to overwrite another DB either
Expand Down
10 changes: 5 additions & 5 deletions flow/connectors/utils/cdc_records/cdc_records_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package cdc_records
import (
"crypto/rand"
"log/slog"
"math/big"
"testing"
"time"

"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"

"github.com/PeerDB-io/peer-flow/model"
Expand All @@ -27,9 +27,9 @@ func getTimeForTesting(t *testing.T) time.Time {
return tv
}

func getRatForTesting(t *testing.T) *big.Rat {
func getDecimalForTesting(t *testing.T) decimal.Decimal {
t.Helper()
return big.NewRat(123456789, 987654321)
return decimal.New(9876543210, 123)
}

func genKeyAndRec(t *testing.T) (model.TableWithPkey, model.Record) {
Expand All @@ -40,7 +40,7 @@ func genKeyAndRec(t *testing.T) (model.TableWithPkey, model.Record) {
require.NoError(t, err)

tv := getTimeForTesting(t)
rv := getRatForTesting(t)
rv := getDecimalForTesting(t)

key := model.TableWithPkey{
TableName: "test_src_tbl",
Expand Down Expand Up @@ -126,7 +126,7 @@ func TestRecordsTillSpill(t *testing.T) {
require.NoError(t, cdcRecordsStore.Close())
}

func TestTimeAndRatEncoding(t *testing.T) {
func TestTimeAndDecimalEncoding(t *testing.T) {
t.Parallel()

cdcRecordsStore := NewCDCRecordsStore("test_time_encoding")
Expand Down
30 changes: 12 additions & 18 deletions flow/e2e/bigquery/bigquery_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"cloud.google.com/go/bigquery"
"cloud.google.com/go/civil"
"github.com/shopspring/decimal"
"google.golang.org/api/iterator"

peer_bq "github.com/PeerDB-io/peer-flow/connectors/bigquery"
Expand Down Expand Up @@ -227,7 +228,11 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) {
case time.Time:
return qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: v}, nil
case *big.Rat:
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v}, nil
val, err := decimal.NewFromString(v.FloatString(32))
if err != nil {
return qvalue.QValue{}, fmt.Errorf("bqHelper failed to parse as decimal %v", v)
}
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: val}, nil
case []uint8:
return qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: v}, nil
case []bigquery.Value:
Expand Down Expand Up @@ -417,37 +422,26 @@ func (b *BigQueryTestHelper) CheckNull(tableName string, colName []string) (bool
}

// check if NaN, Inf double values are null
func (b *BigQueryTestHelper) CheckDoubleValues(tableName string, c1 string, c2 string) (bool, error) {
command := fmt.Sprintf("SELECT %s, %s FROM `%s.%s`",
c1, c2, b.Config.DatasetId, tableName)
func (b *BigQueryTestHelper) SelectRow(tableName string, cols ...string) ([]bigquery.Value, error) {
command := fmt.Sprintf("SELECT %s FROM `%s.%s`",
strings.Join(cols, ","), b.Config.DatasetId, tableName)
q := b.client.Query(command)
q.DisableQueryCache = true
it, err := q.Read(context.Background())
if err != nil {
return false, fmt.Errorf("failed to run command: %w", err)
return nil, fmt.Errorf("failed to run command: %w", err)
}

var row []bigquery.Value
for {
err := it.Next(&row)
if err == iterator.Done {
break
return row, nil
}
if err != nil {
return false, fmt.Errorf("failed to iterate over query results: %w", err)
return nil, fmt.Errorf("failed to iterate over query results: %w", err)
}
}

if len(row) == 0 {
return false, nil
}

floatArr, _ := row[1].([]float64)
if row[0] != nil || len(floatArr) > 0 {
return false, nil
}

return true, nil
}

func qValueKindToBqColTypeString(val qvalue.QValueKind) (string, error) {
Expand Down
19 changes: 14 additions & 5 deletions flow/e2e/bigquery/peer_flow_bq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"cloud.google.com/go/bigquery"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -478,7 +479,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_NaN_Doubles_BQ() {
srcTableName := s.attachSchemaSuffix("test_nans_bq")
dstTableName := "test_nans_bq"
_, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 double precision,c2 double precision[]);
CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 double precision,c2 double precision[],c3 numeric);
`, srcTableName))
require.NoError(s.t, err)

Expand All @@ -499,13 +500,21 @@ func (s PeerFlowE2ETestSuiteBQ) Test_NaN_Doubles_BQ() {

// test inserting various types
_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`
INSERT INTO %s SELECT 2, 'NaN'::double precision, '{NaN, Infinity, -Infinity}';
`, srcTableName))
INSERT INTO %s SELECT 2, 'NaN'::double precision, '{NaN, Infinity, -Infinity}', 'NaN'::numeric;
`, srcTableName))
e2e.EnvNoError(s.t, env, err)

e2e.EnvWaitFor(s.t, env, 2*time.Minute, "normalize weird floats", func() bool {
good, err := s.bqHelper.CheckDoubleValues(dstTableName, "c1", "c2")
return err == nil && good
row, err := s.bqHelper.SelectRow(dstTableName, "c1", "c2", "c3")
if err != nil {
return false
}
if len(row) == 0 {
return false
}

floatArr, ok := row[1].([]bigquery.Value)
return ok && row[0] == nil && len(floatArr) == 0 && row[2] == nil
})

env.Cancel()
Expand Down
13 changes: 4 additions & 9 deletions flow/e2e/postgres/qrep_flow_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,10 @@ func (s PeerFlowE2ETestSuitePG) setupSourceTable(tableName string, rowCount int)

func (s PeerFlowE2ETestSuitePG) comparePGTables(srcSchemaQualified, dstSchemaQualified, selector string) error {
// Execute the two EXCEPT queries
if err := s.compareQuery(srcSchemaQualified, dstSchemaQualified, selector); err != nil {
return err
}
if err := s.compareQuery(dstSchemaQualified, srcSchemaQualified, selector); err != nil {
return err
}

// If no error is returned, then the contents of the two tables are the same
return nil
return errors.Join(
s.compareQuery(srcSchemaQualified, dstSchemaQualified, selector),
s.compareQuery(dstSchemaQualified, srcSchemaQualified, selector),
)
}

func (s PeerFlowE2ETestSuitePG) checkEnums(srcSchemaQualified, dstSchemaQualified string) error {
Expand Down
8 changes: 4 additions & 4 deletions flow/e2e/snowflake/snowflake_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"encoding/json"
"errors"
"fmt"
"math/big"
"os"
"time"

"github.com/shopspring/decimal"

connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
"github.com/PeerDB-io/peer-flow/e2eshared"
"github.com/PeerDB-io/peer-flow/generated/protos"
Expand Down Expand Up @@ -173,9 +174,8 @@ func (s *SnowflakeTestHelper) RunIntQuery(query string) (int, error) {
case qvalue.QValueKindInt64:
return int(rec[0].Value.(int64)), nil
case qvalue.QValueKindNumeric:
// get big.Rat and convert to int
rat := rec[0].Value.(*big.Rat)
return int(rat.Num().Int64() / rat.Denom().Int64()), nil
val := rec[0].Value.(decimal.Decimal)
return int(val.IntPart()), nil
default:
return 0, fmt.Errorf("failed to execute query: %s, returned value of type %s", query, rec[0].Kind)
}
Expand Down
Loading

0 comments on commit 1110158

Please sign in to comment.