Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace *big.Rat with shopspring decimal.Decimal #1491

Merged
merged 2 commits into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flow/connectors/postgres/qrep_query_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"bytes"
"context"
"fmt"
"math/big"
"testing"
"time"

"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/shopspring/decimal"

"github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
)
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestAllDataTypes(t *testing.T) {
}

expectedNumeric := "123.456"
actualNumeric := record[10].Value.(*big.Rat).FloatString(3)
actualNumeric := record[10].Value.(decimal.Decimal).String()
if actualNumeric != expectedNumeric {
t.Fatalf("expected %v, got %v", expectedNumeric, actualNumeric)
}
Expand Down
46 changes: 14 additions & 32 deletions flow/connectors/postgres/qvalue_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@ import (
"encoding/json"
"errors"
"fmt"
"math/big"
"net/netip"
"strings"
"time"

"github.com/jackc/pgx/v5/pgtype"
"github.com/lib/pq/oid"
"github.com/shopspring/decimal"

"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/PeerDB-io/peer-flow/shared"
)

var big10 = big.NewInt(10)

func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind {
switch recvOID {
case pgtype.BoolOID:
Expand Down Expand Up @@ -217,8 +215,7 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
val := qvalue.QValue{}

if value == nil {
val = qvalue.QValue{Kind: qvalueKind, Value: nil}
return val, nil
return qvalue.QValue{Kind: qvalueKind, Value: nil}, nil
}

switch qvalueKind {
Expand Down Expand Up @@ -341,11 +338,11 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
case qvalue.QValueKindNumeric:
numVal := value.(pgtype.Numeric)
if numVal.Valid {
rat, err := numericToRat(&numVal)
num, err := numericToDecimal(numVal)
if err != nil {
return qvalue.QValue{}, fmt.Errorf("failed to convert numeric [%v] to rat: %w", value, err)
return qvalue.QValue{}, fmt.Errorf("failed to convert numeric [%v] to decimal: %w", value, err)
}
val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: rat}
val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: num}
}
case qvalue.QValueKindArrayFloat32:
return convertToArray[float32](qvalueKind, value)
Expand Down Expand Up @@ -388,31 +385,16 @@ func (c *PostgresConnector) parseFieldFromPostgresOID(oid uint32, value interfac
return parseFieldFromQValueKind(c.postgresOIDToQValueKind(oid), value)
}

func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) {
if numVal.Valid {
if numVal.NaN {
// set to nil if NaN
return nil, nil
}

switch numVal.InfinityModifier {
case pgtype.NegativeInfinity, pgtype.Infinity:
return nil, nil
}

rat := new(big.Rat).SetInt(numVal.Int)
if numVal.Exp > 0 {
mul := new(big.Int).Exp(big10, big.NewInt(int64(numVal.Exp)), nil)
rat.Mul(rat, new(big.Rat).SetInt(mul))
} else if numVal.Exp < 0 {
mul := new(big.Int).Exp(big10, big.NewInt(int64(-numVal.Exp)), nil)
rat.Quo(rat, new(big.Rat).SetInt(mul))
}
return rat, nil
func numericToDecimal(numVal pgtype.Numeric) (interface{}, error) {
switch {
case !numVal.Valid:
return nil, errors.New("invalid numeric")
case numVal.NaN, numVal.InfinityModifier == pgtype.Infinity,
numVal.InfinityModifier == pgtype.NegativeInfinity:
return nil, nil
default:
return decimal.NewFromBigInt(numVal.Int, numVal.Exp), nil
}

// handle invalid numeric
return nil, errors.New("invalid numeric")
}

func customTypeToQKind(typeName string) qvalue.QValueKind {
Expand Down
5 changes: 2 additions & 3 deletions flow/connectors/snowflake/avro_file_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package connsnowflake
import (
"context"
"fmt"
"math/big"
"os"
"testing"
"time"

"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"

avro "github.com/PeerDB-io/peer-flow/connectors/utils/avro"
Expand Down Expand Up @@ -36,8 +36,7 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.
qvalue.QValueKindTimeTZ, qvalue.QValueKindDate:
value = time.Now()
case qvalue.QValueKindNumeric:
// create a new big.Rat for numeric data
value = big.NewRat(int64(placeHolder), 1)
value = decimal.New(int64(placeHolder), 1)
case qvalue.QValueKindUUID:
value = uuid.New() // assuming you have the github.com/google/uuid package
case qvalue.QValueKindQChar:
Expand Down
12 changes: 4 additions & 8 deletions flow/connectors/sql/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import (
"encoding/json"
"fmt"
"log/slog"
"math/big"
"strings"

"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jmoiron/sqlx"
"github.com/shopspring/decimal"
"go.temporal.io/sdk/activity"
"go.temporal.io/sdk/log"

Expand Down Expand Up @@ -212,7 +212,7 @@ func (g *GenericSQLQueryExecutor) processRows(ctx context.Context, rows *sqlx.Ro
case qvalue.QValueKindBytes, qvalue.QValueKindBit:
values[i] = new([]byte)
case qvalue.QValueKindNumeric:
var s sql.NullString
var s sql.Null[decimal.Decimal]
values[i] = &s
case qvalue.QValueKindUUID:
values[i] = new([]byte)
Expand Down Expand Up @@ -380,13 +380,9 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) {
}
}
case qvalue.QValueKindNumeric:
if v, ok := val.(*sql.NullString); ok {
if v, ok := val.(*sql.Null[decimal.Decimal]); ok {
if v.Valid {
numeric := new(big.Rat)
if _, ok := numeric.SetString(v.String); !ok {
return qvalue.QValue{}, fmt.Errorf("failed to parse numeric: %v", v.String)
}
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: numeric}, nil
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v.V}, nil
} else {
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: nil}, nil
}
Expand Down
4 changes: 2 additions & 2 deletions flow/connectors/utils/cdc_records/cdc_records_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"errors"
"fmt"
"log/slog"
"math/big"
"os"
"runtime"
"sync/atomic"
"time"

"github.com/cockroachdb/pebble"
"github.com/shopspring/decimal"
"go.temporal.io/sdk/log"

"github.com/PeerDB-io/peer-flow/model"
Expand Down Expand Up @@ -72,7 +72,7 @@ func (c *cdcRecordsStore) initPebbleDB() error {
gob.Register(&model.UpdateRecord{})
gob.Register(&model.DeleteRecord{})
gob.Register(time.Time{})
gob.Register(&big.Rat{})
gob.Register(decimal.Decimal{})

var err error
// we don't want a WAL since cache, we don't want to overwrite another DB either
Expand Down
10 changes: 5 additions & 5 deletions flow/connectors/utils/cdc_records/cdc_records_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package cdc_records
import (
"crypto/rand"
"log/slog"
"math/big"
"testing"
"time"

"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"

"github.com/PeerDB-io/peer-flow/model"
Expand All @@ -27,9 +27,9 @@ func getTimeForTesting(t *testing.T) time.Time {
return tv
}

func getRatForTesting(t *testing.T) *big.Rat {
func getDecimalForTesting(t *testing.T) decimal.Decimal {
t.Helper()
return big.NewRat(123456789, 987654321)
return decimal.New(9876543210, 123)
}

func genKeyAndRec(t *testing.T) (model.TableWithPkey, model.Record) {
Expand All @@ -40,7 +40,7 @@ func genKeyAndRec(t *testing.T) (model.TableWithPkey, model.Record) {
require.NoError(t, err)

tv := getTimeForTesting(t)
rv := getRatForTesting(t)
rv := getDecimalForTesting(t)

key := model.TableWithPkey{
TableName: "test_src_tbl",
Expand Down Expand Up @@ -126,7 +126,7 @@ func TestRecordsTillSpill(t *testing.T) {
require.NoError(t, cdcRecordsStore.Close())
}

func TestTimeAndRatEncoding(t *testing.T) {
func TestTimeAndDecimalEncoding(t *testing.T) {
t.Parallel()

cdcRecordsStore := NewCDCRecordsStore("test_time_encoding")
Expand Down
30 changes: 12 additions & 18 deletions flow/e2e/bigquery/bigquery_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"cloud.google.com/go/bigquery"
"cloud.google.com/go/civil"
"github.com/shopspring/decimal"
"google.golang.org/api/iterator"

peer_bq "github.com/PeerDB-io/peer-flow/connectors/bigquery"
Expand Down Expand Up @@ -227,7 +228,11 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) {
case time.Time:
return qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: v}, nil
case *big.Rat:
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v}, nil
val, err := decimal.NewFromString(v.FloatString(32))
if err != nil {
return qvalue.QValue{}, fmt.Errorf("bqHelper failed to parse as decimal %v", v)
}
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: val}, nil
case []uint8:
return qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: v}, nil
case []bigquery.Value:
Expand Down Expand Up @@ -417,37 +422,26 @@ func (b *BigQueryTestHelper) CheckNull(tableName string, colName []string) (bool
}

// check if NaN, Inf double values are null
func (b *BigQueryTestHelper) CheckDoubleValues(tableName string, c1 string, c2 string) (bool, error) {
command := fmt.Sprintf("SELECT %s, %s FROM `%s.%s`",
c1, c2, b.Config.DatasetId, tableName)
func (b *BigQueryTestHelper) SelectRow(tableName string, cols ...string) ([]bigquery.Value, error) {
command := fmt.Sprintf("SELECT %s FROM `%s.%s`",
strings.Join(cols, ","), b.Config.DatasetId, tableName)
q := b.client.Query(command)
q.DisableQueryCache = true
it, err := q.Read(context.Background())
if err != nil {
return false, fmt.Errorf("failed to run command: %w", err)
return nil, fmt.Errorf("failed to run command: %w", err)
}

var row []bigquery.Value
for {
err := it.Next(&row)
if err == iterator.Done {
break
return row, nil
}
if err != nil {
return false, fmt.Errorf("failed to iterate over query results: %w", err)
return nil, fmt.Errorf("failed to iterate over query results: %w", err)
}
}

if len(row) == 0 {
return false, nil
}

floatArr, _ := row[1].([]float64)
if row[0] != nil || len(floatArr) > 0 {
return false, nil
}

return true, nil
}

func qValueKindToBqColTypeString(val qvalue.QValueKind) (string, error) {
Expand Down
19 changes: 14 additions & 5 deletions flow/e2e/bigquery/peer_flow_bq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"cloud.google.com/go/bigquery"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -478,7 +479,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_NaN_Doubles_BQ() {
srcTableName := s.attachSchemaSuffix("test_nans_bq")
dstTableName := "test_nans_bq"
_, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 double precision,c2 double precision[]);
CREATE TABLE IF NOT EXISTS %s (id serial PRIMARY KEY,c1 double precision,c2 double precision[],c3 numeric);
`, srcTableName))
require.NoError(s.t, err)

Expand All @@ -499,13 +500,21 @@ func (s PeerFlowE2ETestSuiteBQ) Test_NaN_Doubles_BQ() {

// test inserting various types
_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`
INSERT INTO %s SELECT 2, 'NaN'::double precision, '{NaN, Infinity, -Infinity}';
`, srcTableName))
INSERT INTO %s SELECT 2, 'NaN'::double precision, '{NaN, Infinity, -Infinity}', 'NaN'::numeric;
`, srcTableName))
e2e.EnvNoError(s.t, env, err)

e2e.EnvWaitFor(s.t, env, 2*time.Minute, "normalize weird floats", func() bool {
good, err := s.bqHelper.CheckDoubleValues(dstTableName, "c1", "c2")
return err == nil && good
row, err := s.bqHelper.SelectRow(dstTableName, "c1", "c2", "c3")
if err != nil {
return false
}
if len(row) == 0 {
return false
}

floatArr, ok := row[1].([]bigquery.Value)
return ok && row[0] == nil && len(floatArr) == 0 && row[2] == nil
})

env.Cancel()
Expand Down
13 changes: 4 additions & 9 deletions flow/e2e/postgres/qrep_flow_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,10 @@ func (s PeerFlowE2ETestSuitePG) setupSourceTable(tableName string, rowCount int)

func (s PeerFlowE2ETestSuitePG) comparePGTables(srcSchemaQualified, dstSchemaQualified, selector string) error {
// Execute the two EXCEPT queries
if err := s.compareQuery(srcSchemaQualified, dstSchemaQualified, selector); err != nil {
return err
}
if err := s.compareQuery(dstSchemaQualified, srcSchemaQualified, selector); err != nil {
return err
}

// If no error is returned, then the contents of the two tables are the same
return nil
return errors.Join(
s.compareQuery(srcSchemaQualified, dstSchemaQualified, selector),
s.compareQuery(dstSchemaQualified, srcSchemaQualified, selector),
)
}

func (s PeerFlowE2ETestSuitePG) checkEnums(srcSchemaQualified, dstSchemaQualified string) error {
Expand Down
8 changes: 4 additions & 4 deletions flow/e2e/snowflake/snowflake_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"encoding/json"
"errors"
"fmt"
"math/big"
"os"
"time"

"github.com/shopspring/decimal"

connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
"github.com/PeerDB-io/peer-flow/e2eshared"
"github.com/PeerDB-io/peer-flow/generated/protos"
Expand Down Expand Up @@ -173,9 +174,8 @@ func (s *SnowflakeTestHelper) RunIntQuery(query string) (int, error) {
case qvalue.QValueKindInt64:
return int(rec[0].Value.(int64)), nil
case qvalue.QValueKindNumeric:
// get big.Rat and convert to int
rat := rec[0].Value.(*big.Rat)
return int(rat.Num().Int64() / rat.Denom().Int64()), nil
val := rec[0].Value.(decimal.Decimal)
return int(val.IntPart()), nil
default:
return 0, fmt.Errorf("failed to execute query: %s, returned value of type %s", query, rec[0].Kind)
}
Expand Down
Loading
Loading