Skip to content

Commit

Permalink
Replace *big.Rat with shopspring decimal.Decimal
Browse files Browse the repository at this point in the history
Rational numbers are awkward decimals,
using a decimal type for decimal data has a few benefits:
1. perf-wise decimal representation is a *big.Int & an int32 exponent, instead of 2 *big.Ints
2. decimal has obvious conversion to/from decimal strings, whereas *big.Rat can be wonky going through FloatString
3. with scripting most people will want to use their decimal values as decimals, not decimal rationals

Fixes #1175
  • Loading branch information
serprex committed Mar 16, 2024
1 parent 78a8cdb commit a2cc30c
Show file tree
Hide file tree
Showing 15 changed files with 154 additions and 214 deletions.
4 changes: 2 additions & 2 deletions flow/connectors/postgres/qrep_query_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"bytes"
"context"
"fmt"
"math/big"
"testing"
"time"

"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/shopspring/decimal"

"github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
)
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestAllDataTypes(t *testing.T) {
}

expectedNumeric := "123.456"
actualNumeric := record[10].Value.(*big.Rat).FloatString(3)
actualNumeric := record[10].Value.(decimal.Decimal).String()
if actualNumeric != expectedNumeric {
t.Fatalf("expected %v, got %v", expectedNumeric, actualNumeric)
}
Expand Down
46 changes: 14 additions & 32 deletions flow/connectors/postgres/qvalue_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@ import (
"encoding/json"
"errors"
"fmt"
"math/big"
"net/netip"
"strings"
"time"

"github.com/jackc/pgx/v5/pgtype"
"github.com/lib/pq/oid"
"github.com/shopspring/decimal"

"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/PeerDB-io/peer-flow/shared"
)

var big10 = big.NewInt(10)

func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValueKind {
switch recvOID {
case pgtype.BoolOID:
Expand Down Expand Up @@ -217,8 +215,7 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
val := qvalue.QValue{}

if value == nil {
val = qvalue.QValue{Kind: qvalueKind, Value: nil}
return val, nil
return qvalue.QValue{Kind: qvalueKind, Value: nil}, nil
}

switch qvalueKind {
Expand Down Expand Up @@ -341,11 +338,11 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
case qvalue.QValueKindNumeric:
numVal := value.(pgtype.Numeric)
if numVal.Valid {
rat, err := numericToRat(&numVal)
num, err := numericToDecimal(numVal)
if err != nil {
return qvalue.QValue{}, fmt.Errorf("failed to convert numeric [%v] to rat: %w", value, err)
return qvalue.QValue{}, fmt.Errorf("failed to convert numeric [%v] to decimal: %w", value, err)
}
val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: rat}
val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: num}
}
case qvalue.QValueKindArrayFloat32:
return convertToArray[float32](qvalueKind, value)
Expand Down Expand Up @@ -388,31 +385,16 @@ func (c *PostgresConnector) parseFieldFromPostgresOID(oid uint32, value interfac
return parseFieldFromQValueKind(c.postgresOIDToQValueKind(oid), value)
}

func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) {
if numVal.Valid {
if numVal.NaN {
// set to nil if NaN
return nil, nil
}

switch numVal.InfinityModifier {
case pgtype.NegativeInfinity, pgtype.Infinity:
return nil, nil
}

rat := new(big.Rat).SetInt(numVal.Int)
if numVal.Exp > 0 {
mul := new(big.Int).Exp(big10, big.NewInt(int64(numVal.Exp)), nil)
rat.Mul(rat, new(big.Rat).SetInt(mul))
} else if numVal.Exp < 0 {
mul := new(big.Int).Exp(big10, big.NewInt(int64(-numVal.Exp)), nil)
rat.Quo(rat, new(big.Rat).SetInt(mul))
}
return rat, nil
func numericToDecimal(numVal pgtype.Numeric) (interface{}, error) {
switch {
case !numVal.Valid:
return nil, errors.New("invalid numeric")
case numVal.NaN, numVal.InfinityModifier == pgtype.Infinity,
numVal.InfinityModifier == pgtype.NegativeInfinity:
return nil, nil
default:
return decimal.NewFromBigInt(numVal.Int, numVal.Exp), nil
}

// handle invalid numeric
return nil, errors.New("invalid numeric")
}

func customTypeToQKind(typeName string) qvalue.QValueKind {
Expand Down
5 changes: 2 additions & 3 deletions flow/connectors/snowflake/avro_file_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package connsnowflake
import (
"context"
"fmt"
"math/big"
"os"
"testing"
"time"

"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"

avro "github.com/PeerDB-io/peer-flow/connectors/utils/avro"
Expand Down Expand Up @@ -36,8 +36,7 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.
qvalue.QValueKindTimeTZ, qvalue.QValueKindDate:
value = time.Now()
case qvalue.QValueKindNumeric:
// create a new big.Rat for numeric data
value = big.NewRat(int64(placeHolder), 1)
value = decimal.New(int64(placeHolder), 1)
case qvalue.QValueKindUUID:
value = uuid.New() // assuming you have the github.com/google/uuid package
case qvalue.QValueKindQChar:
Expand Down
12 changes: 4 additions & 8 deletions flow/connectors/sql/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import (
"encoding/json"
"fmt"
"log/slog"
"math/big"
"strings"

"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jmoiron/sqlx"
"github.com/shopspring/decimal"
"go.temporal.io/sdk/activity"
"go.temporal.io/sdk/log"

Expand Down Expand Up @@ -212,7 +212,7 @@ func (g *GenericSQLQueryExecutor) processRows(ctx context.Context, rows *sqlx.Ro
case qvalue.QValueKindBytes, qvalue.QValueKindBit:
values[i] = new([]byte)
case qvalue.QValueKindNumeric:
var s sql.NullString
var s sql.Null[decimal.Decimal]
values[i] = &s
case qvalue.QValueKindUUID:
values[i] = new([]byte)
Expand Down Expand Up @@ -380,13 +380,9 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) {
}
}
case qvalue.QValueKindNumeric:
if v, ok := val.(*sql.NullString); ok {
if v, ok := val.(*sql.Null[decimal.Decimal]); ok {
if v.Valid {
numeric := new(big.Rat)
if _, ok := numeric.SetString(v.String); !ok {
return qvalue.QValue{}, fmt.Errorf("failed to parse numeric: %v", v.String)
}
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: numeric}, nil
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v.V}, nil
} else {
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: nil}, nil
}
Expand Down
4 changes: 2 additions & 2 deletions flow/connectors/utils/cdc_records/cdc_records_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"errors"
"fmt"
"log/slog"
"math/big"
"os"
"runtime"
"sync/atomic"
"time"

"github.com/cockroachdb/pebble"
"github.com/shopspring/decimal"
"go.temporal.io/sdk/log"

"github.com/PeerDB-io/peer-flow/model"
Expand Down Expand Up @@ -72,7 +72,7 @@ func (c *cdcRecordsStore) initPebbleDB() error {
gob.Register(&model.UpdateRecord{})
gob.Register(&model.DeleteRecord{})
gob.Register(time.Time{})
gob.Register(&big.Rat{})
gob.Register(decimal.Decimal{})

var err error
// we don't want a WAL since cache, we don't want to overwrite another DB either
Expand Down
10 changes: 5 additions & 5 deletions flow/connectors/utils/cdc_records/cdc_records_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package cdc_records
import (
"crypto/rand"
"log/slog"
"math/big"
"testing"
"time"

"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"

"github.com/PeerDB-io/peer-flow/model"
Expand All @@ -27,9 +27,9 @@ func getTimeForTesting(t *testing.T) time.Time {
return tv
}

func getRatForTesting(t *testing.T) *big.Rat {
func getDecimalForTesting(t *testing.T) decimal.Decimal {
t.Helper()
return big.NewRat(123456789, 987654321)
return decimal.New(9876543210, 123)
}

func genKeyAndRec(t *testing.T) (model.TableWithPkey, model.Record) {
Expand All @@ -40,7 +40,7 @@ func genKeyAndRec(t *testing.T) (model.TableWithPkey, model.Record) {
require.NoError(t, err)

tv := getTimeForTesting(t)
rv := getRatForTesting(t)
rv := getDecimalForTesting(t)

key := model.TableWithPkey{
TableName: "test_src_tbl",
Expand Down Expand Up @@ -126,7 +126,7 @@ func TestRecordsTillSpill(t *testing.T) {
require.NoError(t, cdcRecordsStore.Close())
}

func TestTimeAndRatEncoding(t *testing.T) {
func TestTimeAndDecimalEncoding(t *testing.T) {
t.Parallel()

cdcRecordsStore := NewCDCRecordsStore("test_time_encoding")
Expand Down
7 changes: 6 additions & 1 deletion flow/e2e/bigquery/bigquery_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"cloud.google.com/go/bigquery"
"cloud.google.com/go/civil"
"github.com/shopspring/decimal"
"google.golang.org/api/iterator"

peer_bq "github.com/PeerDB-io/peer-flow/connectors/bigquery"
Expand Down Expand Up @@ -227,7 +228,11 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) {
case time.Time:
return qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: v}, nil
case *big.Rat:
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v}, nil
val, err := decimal.NewFromString(v.FloatString(32))
if err != nil {
return qvalue.QValue{}, fmt.Errorf("bqHelper failed to parse as decimal %v", v)
}
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: val}, nil
case []uint8:
return qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: v}, nil
case []bigquery.Value:
Expand Down
13 changes: 4 additions & 9 deletions flow/e2e/postgres/qrep_flow_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,10 @@ func (s PeerFlowE2ETestSuitePG) setupSourceTable(tableName string, rowCount int)

func (s PeerFlowE2ETestSuitePG) comparePGTables(srcSchemaQualified, dstSchemaQualified, selector string) error {
// Execute the two EXCEPT queries
if err := s.compareQuery(srcSchemaQualified, dstSchemaQualified, selector); err != nil {
return err
}
if err := s.compareQuery(dstSchemaQualified, srcSchemaQualified, selector); err != nil {
return err
}

// If no error is returned, then the contents of the two tables are the same
return nil
return errors.Join(
s.compareQuery(srcSchemaQualified, dstSchemaQualified, selector),
s.compareQuery(dstSchemaQualified, srcSchemaQualified, selector),
)
}

func (s PeerFlowE2ETestSuitePG) checkEnums(srcSchemaQualified, dstSchemaQualified string) error {
Expand Down
8 changes: 4 additions & 4 deletions flow/e2e/snowflake/snowflake_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"encoding/json"
"errors"
"fmt"
"math/big"
"os"
"time"

"github.com/shopspring/decimal"

connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
"github.com/PeerDB-io/peer-flow/e2eshared"
"github.com/PeerDB-io/peer-flow/generated/protos"
Expand Down Expand Up @@ -173,9 +174,8 @@ func (s *SnowflakeTestHelper) RunIntQuery(query string) (int, error) {
case qvalue.QValueKindInt64:
return int(rec[0].Value.(int64)), nil
case qvalue.QValueKindNumeric:
// get big.Rat and convert to int
rat := rec[0].Value.(*big.Rat)
return int(rat.Num().Int64() / rat.Denom().Int64()), nil
val := rec[0].Value.(decimal.Decimal)
return int(val.IntPart()), nil
default:
return 0, fmt.Errorf("failed to execute query: %s, returned value of type %s", query, rec[0].Kind)
}
Expand Down
2 changes: 1 addition & 1 deletion flow/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ require (
github.com/linkedin/goavro/v2 v2.12.0
github.com/microsoft/go-mssqldb v1.7.0
github.com/orcaman/concurrent-map/v2 v2.0.1
github.com/shopspring/decimal v1.3.1
github.com/slack-go/slack v0.12.5
github.com/snowflakedb/gosnowflake v1.8.0
github.com/stretchr/testify v1.9.0
Expand Down Expand Up @@ -89,7 +90,6 @@ require (
github.com/prometheus/procfs v0.13.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
Expand Down
18 changes: 3 additions & 15 deletions flow/model/qrecord_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"errors"
"fmt"
"log/slog"
"math/big"
"time"

"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/shopspring/decimal"

"github.com/PeerDB-io/peer-flow/geo"
"github.com/PeerDB-io/peer-flow/model/qvalue"
Expand Down Expand Up @@ -198,24 +198,12 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
values[i] = uuid.UUID(v)

case qvalue.QValueKindNumeric:
v, ok := qValue.Value.(*big.Rat)
v, ok := qValue.Value.(decimal.Decimal)
if !ok {
src.err = fmt.Errorf("invalid Numeric value %v", qValue.Value)
return nil, src.err
}
if v == nil {
values[i] = pgtype.Numeric{
Int: nil,
Exp: 0,
NaN: true,
InfinityModifier: pgtype.Finite,
Valid: true,
}
break
}

// TODO: account for precision and scale issues.
values[i] = v.FloatString(38)
values[i] = v.String()

case qvalue.QValueKindBytes, qvalue.QValueKindBit:
v, ok := qValue.Value.([]byte)
Expand Down
8 changes: 4 additions & 4 deletions flow/model/qrecord_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package model_test

import (
"math/big"
"testing"

"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"

"github.com/PeerDB-io/peer-flow/e2eshared"
Expand Down Expand Up @@ -35,14 +35,14 @@ func TestEquals(t *testing.T) {
},
{
name: "Equal - Same numeric",
q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: big.NewRat(10, 2)}},
q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: decimal.NewFromInt(5)}},
q2: []qvalue.QValue{{Kind: qvalue.QValueKindString, Value: "5"}},
want: true,
},
{
name: "Not Equal - Different numeric",
q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: big.NewRat(10, 2)}},
q2: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: "4.99"}},
q1: []qvalue.QValue{{Kind: qvalue.QValueKindNumeric, Value: decimal.NewFromInt(5)}},
q2: []qvalue.QValue{{Kind: qvalue.QValueKindString, Value: "4.99"}},
want: false,
},
}
Expand Down
Loading

0 comments on commit a2cc30c

Please sign in to comment.