Commit

Merge remote-tracking branch 'origin/main' into less-noisy-type-warning
serprex committed Feb 14, 2024
2 parents ed44388 + cc61dfb commit 5a538b1
Showing 19 changed files with 210 additions and 30 deletions.
1 change: 1 addition & 0 deletions flow/connectors/clickhouse/qvalue_convert.go
@@ -16,6 +16,7 @@ var clickhouseTypeToQValueKindMap = map[string]qvalue.QValueKind{
"CHAR": qvalue.QValueKindString,
"TEXT": qvalue.QValueKindString,
"String": qvalue.QValueKindString,
"FixedString(1)": qvalue.QValueKindQChar,
"Bool": qvalue.QValueKindBoolean,
"DateTime": qvalue.QValueKindTimestamp,
"TIMESTAMP": qvalue.QValueKindTimestamp,
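
The new ClickHouse entry matches only the exact type string FixedString(1). A standalone sketch (not part of this commit; the map subset and fallback handling here are illustrative) of how such a lookup distinguishes it from wider FixedString(N) columns:

package main

import "fmt"

type QValueKind string

const (
	QValueKindString QValueKind = "string"
	QValueKindQChar  QValueKind = "qchar"
)

// Illustrative subset of clickhouseTypeToQValueKindMap.
var chTypeToKind = map[string]QValueKind{
	"String":         QValueKindString,
	"FixedString(1)": QValueKindQChar,
}

func main() {
	for _, typ := range []string{"FixedString(1)", "FixedString(16)"} {
		if kind, ok := chTypeToKind[typ]; ok {
			fmt.Printf("%s -> %s\n", typ, kind)
		} else {
			// Wider FixedString(N) types miss this map; whatever fallback the
			// real connector applies is outside this sketch.
			fmt.Printf("%s -> no direct mapping\n", typ)
		}
	}
}
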
9 changes: 9 additions & 0 deletions flow/connectors/postgres/cdc.go
@@ -716,6 +716,15 @@ func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, forma
parsedData, err = dt.Codec.DecodeValue(p.typeMap, dataType, formatCode, data)
}
if err != nil {
if dt.Name == "time" || dt.Name == "timetz" ||
dt.Name == "timestamp" || dt.Name == "timestamptz" {
// indicates year is more than 4 digits or something similar,
// which you can insert into postgres,
// but not representable by time.Time
p.logger.Warn(fmt.Sprintf("Invalidated and hence nulled %s data: %s",
dt.Name, string(data)))
return qvalue.QValue{}, nil
}
return qvalue.QValue{}, err
}
retVal, err := p.parseFieldFromPostgresOID(dataType, parsedData)
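
The guard above nulls time-like values that Postgres accepts but Go's time.Time-based decoding cannot parse, such as timestamps with a five-digit year. A standalone illustration (not part of this commit) of that failure mode, assuming text-format values decoded with a fixed four-digit-year layout roughly like pgx's codecs use:

package main

import (
	"fmt"
	"time"
)

func main() {
	const layout = "2006-01-02 15:04:05"

	// Valid in a Postgres timestamp column, but the four-digit-year layout
	// cannot consume the fifth digit, so parsing fails.
	_, err := time.Parse(layout, "10001-03-14 23:05:52")
	fmt.Println(err != nil) // true

	// An ordinary four-digit year parses fine.
	ts, err := time.Parse(layout, "2024-02-14 23:05:52")
	fmt.Println(ts, err) // 2024-02-14 23:05:52 +0000 UTC <nil>
}
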
17 changes: 13 additions & 4 deletions flow/connectors/postgres/client.go
@@ -629,7 +629,7 @@ func (c *PostgresConnector) getDefaultPublicationName(jobName string) string {

func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error {
if c.conn == nil {
return fmt.Errorf("check tables: conn is nil")
return errors.New("check tables: conn is nil")
}

// Check that we can select from all tables
@@ -649,11 +649,20 @@ func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []
}
}

// Check if tables belong to publication
tableStr := strings.Join(tableArr, ",")
if pubName != "" {
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}

// Check if tables belong to publication
var pubTableCount int
err := c.conn.QueryRow(ctx, fmt.Sprintf(`
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
@@ -663,7 +672,7 @@ func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []
}

if pubTableCount != len(tableNames) {
return fmt.Errorf("not all tables belong to publication")
return errors.New("not all tables belong to publication")
}
}

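
The added lookup turns a missing publication into an explicit error instead of letting the later COUNT query report a misleading "not all tables belong to publication" mismatch. An equivalent check can also be written with SELECT EXISTS, which avoids special-casing pgx.ErrNoRows; the sketch below is illustrative only (not part of this commit), and the helper name and error wrapping are assumptions:

package pgchecks

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5"
)

// checkPublicationExists is a sketch of one possible existence check.
func checkPublicationExists(ctx context.Context, conn *pgx.Conn, pubName string) error {
	var exists bool
	err := conn.QueryRow(ctx,
		"SELECT EXISTS(SELECT 1 FROM pg_publication WHERE pubname=$1)", pubName).Scan(&exists)
	if err != nil {
		return fmt.Errorf("error while checking for publication existence: %w", err)
	}
	if !exists {
		return fmt.Errorf("publication does not exist: %s", pubName)
	}
	return nil
}
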
4 changes: 2 additions & 2 deletions flow/connectors/postgres/postgres.go
@@ -977,7 +977,7 @@ func (c *PostgresConnector) AddTablesToPublication(ctx context.Context, req *pro
// just check if we have all the tables already in the publication for custom publications
if req.PublicationName != "" {
rows, err := c.conn.Query(ctx,
"SELECT tablename FROM pg_publication_tables WHERE pubname=$1", req.PublicationName)
"SELECT schemaname || '.' || tablename FROM pg_publication_tables WHERE pubname=$1", req.PublicationName)
if err != nil {
return fmt.Errorf("failed to check tables in publication: %w", err)
}
@@ -986,7 +986,7 @@ func (c *PostgresConnector) AddTablesToPublication(ctx context.Context, req *pro
if err != nil {
return fmt.Errorf("failed to check tables in publication: %w", err)
}
notPresentTables := utils.ArrayMinus(tableNames, additionalSrcTables)
notPresentTables := utils.ArrayMinus(additionalSrcTables, tableNames)
if len(notPresentTables) > 0 {
return fmt.Errorf("some additional tables not present in custom publication: %s",
strings.Join(notPresentTables, ", "))
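
The argument swap matters because the helper is a one-directional set difference: the intent is "tables being added that the publication does not already contain", not the reverse. utils.ArrayMinus is assumed to behave like the sketch below (not part of this commit):

package sliceutil

// arrayMinus returns the elements of first that do not appear in second.
func arrayMinus(first, second []string) []string {
	exclude := make(map[string]struct{}, len(second))
	for _, s := range second {
		exclude[s] = struct{}{}
	}
	var result []string
	for _, s := range first {
		if _, found := exclude[s]; !found {
			result = append(result, s)
		}
	}
	return result
}

// arrayMinus([]string{"s.a", "s.b"}, []string{"s.a"}) == []string{"s.b"}
// arrayMinus([]string{"s.a"}, []string{"s.a", "s.b"}) is empty: order matters.
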
6 changes: 6 additions & 0 deletions flow/connectors/postgres/qvalue_convert.go
@@ -31,6 +31,8 @@ func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValu
return qvalue.QValueKindFloat32
case pgtype.Float8OID:
return qvalue.QValueKindFloat64
case pgtype.QCharOID:
return qvalue.QValueKindQChar
case pgtype.TextOID, pgtype.VarcharOID, pgtype.BPCharOID:
return qvalue.QValueKindString
case pgtype.ByteaOID:
@@ -125,6 +127,8 @@ func qValueKindToPostgresType(colTypeStr string) string {
return "REAL"
case qvalue.QValueKindFloat64:
return "DOUBLE PRECISION"
case qvalue.QValueKindQChar:
return "\"char\""
case qvalue.QValueKindString:
return "TEXT"
case qvalue.QValueKindBytes:
@@ -280,6 +284,8 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
case qvalue.QValueKindFloat64:
floatVal := value.(float64)
val = qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: floatVal}
case qvalue.QValueKindQChar:
val = qvalue.QValue{Kind: qvalue.QValueKindQChar, Value: uint8(value.(rune))}
case qvalue.QValueKindString:
// handling all unsupported types with strings as well for now.
val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: fmt.Sprint(value)}
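
QValueKindQChar maps to the quoted type name because in Postgres an unquoted char is an alias for character(1), while the quoted "char" is the separate one-byte internal type behind pgtype.QCharOID. A small standalone illustration (not part of this commit) of the quoting and of the rune-to-byte narrowing seen in parseFieldFromQValueKind:

package main

import "fmt"

func main() {
	// Quoted "char" selects the one-byte internal type; unquoted char would
	// be rewritten by Postgres to character(1).
	fmt.Printf("CREATE TABLE t (c %s);\n", `"char"`) // CREATE TABLE t (c "char");

	// The decoded value arrives as a rune and is narrowed back to one byte.
	var value interface{} = rune('A')
	fmt.Println(uint8(value.(rune))) // 65
}
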
12 changes: 9 additions & 3 deletions flow/connectors/postgres/schema_delta_test_constants.go
@@ -19,6 +19,7 @@ var AddAllColumnTypes = []string{
string(qvalue.QValueKindJSON),
string(qvalue.QValueKindNumeric),
string(qvalue.QValueKindString),
string(qvalue.QValueKindQChar),
string(qvalue.QValueKindTime),
string(qvalue.QValueKindTimestamp),
string(qvalue.QValueKindTimestampTZ),
@@ -93,21 +94,26 @@ var AddAllColumnTypesFields = []*protos.FieldDescription{
},
{
Name: "c13",
Type: string(qvalue.QValueKindTime),
Type: string(qvalue.QValueKindQChar),
TypeModifier: -1,
},
{
Name: "c14",
Type: string(qvalue.QValueKindTimestamp),
Type: string(qvalue.QValueKindTime),
TypeModifier: -1,
},
{
Name: "c15",
Type: string(qvalue.QValueKindTimestampTZ),
Type: string(qvalue.QValueKindTimestamp),
TypeModifier: -1,
},
{
Name: "c16",
Type: string(qvalue.QValueKindTimestampTZ),
TypeModifier: -1,
},
{
Name: "c17",
Type: string(qvalue.QValueKindUUID),
TypeModifier: -1,
},
3 changes: 3 additions & 0 deletions flow/connectors/snowflake/avro_file_writer_test.go
@@ -40,6 +40,8 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.
value = big.NewRat(int64(placeHolder), 1)
case qvalue.QValueKindUUID:
value = uuid.New() // assuming you have the github.com/google/uuid package
case qvalue.QValueKindQChar:
value = uint8(48)
// case qvalue.QValueKindArray:
// value = []int{1, 2, 3} // placeholder array, replace with actual logic
// case qvalue.QValueKindStruct:
@@ -85,6 +87,7 @@ func generateRecords(
qvalue.QValueKindNumeric,
qvalue.QValueKindBytes,
qvalue.QValueKindUUID,
qvalue.QValueKindQChar,
// qvalue.QValueKindJSON,
qvalue.QValueKindBit,
}
7 changes: 7 additions & 0 deletions flow/connectors/sql/query_executor.go
@@ -308,6 +308,9 @@ func (g *GenericSQLQueryExecutor) CheckNull(ctx context.Context, schema string,
}

func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) {
if val == nil {
return qvalue.QValue{Kind: kind, Value: nil}, nil
}
switch kind {
case qvalue.QValueKindInt32:
if v, ok := val.(*sql.NullInt32); ok {
@@ -341,6 +344,10 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) {
return qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: nil}, nil
}
}
case qvalue.QValueKindQChar:
if v, ok := val.(uint8); ok {
return qvalue.QValue{Kind: qvalue.QValueKindQChar, Value: v}, nil
}
case qvalue.QValueKindString:
if v, ok := val.(*sql.NullString); ok {
if v.Valid {
3 changes: 2 additions & 1 deletion flow/connectors/sqlserver/qvalue_convert.go
@@ -10,6 +10,7 @@ var qValueKindToSQLServerTypeMap = map[qvalue.QValueKind]string{
qvalue.QValueKindFloat32: "REAL",
qvalue.QValueKindFloat64: "FLOAT",
qvalue.QValueKindNumeric: "DECIMAL(38, 9)",
qvalue.QValueKindQChar: "CHAR",
qvalue.QValueKindString: "NTEXT",
qvalue.QValueKindJSON: "NTEXT", // SQL Server doesn't have a native JSON type
qvalue.QValueKindTimestamp: "DATETIME2",
@@ -51,7 +52,7 @@ var sqlServerTypeToQValueKindMap = map[string]qvalue.QValueKind{
"UNIQUEIDENTIFIER": qvalue.QValueKindUUID,
"SMALLINT": qvalue.QValueKindInt32,
"TINYINT": qvalue.QValueKindInt32,
"CHAR": qvalue.QValueKindString,
"CHAR": qvalue.QValueKindQChar,
"VARCHAR": qvalue.QValueKindString,
"NCHAR": qvalue.QValueKindString,
"NVARCHAR": qvalue.QValueKindString,
1 change: 1 addition & 0 deletions flow/connectors/utils/avro/avro_writer.go
@@ -150,6 +150,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter
p.targetDWH,
p.avroSchema.NullableFields,
colNames,
logger,
)

avroMap, err := avroConverter.Convert()
2 changes: 1 addition & 1 deletion flow/e2e/bigquery/bigquery_helper.go
@@ -377,7 +377,7 @@ func (b *BigQueryTestHelper) ExecuteAndProcessQuery(query string) (*model.QRecor
}, nil
}

// returns whether the function errors or there are nulls
// returns whether the function errors or there are no nulls
func (b *BigQueryTestHelper) CheckNull(tableName string, colName []string) (bool, error) {
if len(colName) == 0 {
return true, nil
66 changes: 66 additions & 0 deletions flow/e2e/bigquery/qrep_flow_bq_test.go
@@ -1,7 +1,9 @@
package e2e_bigquery

import (
"context"
"fmt"
"strings"

"github.com/stretchr/testify/require"

@@ -15,6 +17,34 @@ func (s PeerFlowE2ETestSuiteBQ) setupSourceTable(tableName string, rowCount int)
require.NoError(s.t, err)
}

func (s PeerFlowE2ETestSuiteBQ) setupTimeTable(tableName string) {
tblFields := []string{
"watermark_ts timestamp",
"mytimestamp timestamp",
"mytztimestamp timestamptz",
}
tblFieldStr := strings.Join(tblFields, ",")
_, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`
CREATE TABLE e2e_test_%s.%s (
%s
);`, s.bqSuffix, tableName, tblFieldStr))

require.NoError(s.t, err)

var rows []string
row := `(CURRENT_TIMESTAMP,'10001-03-14 23:05:52','50001-03-14 23:05:52.216809+00')`
rows = append(rows, row)

_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`
INSERT INTO e2e_test_%s.%s (
watermark_ts,
mytimestamp,
mytztimestamp
) VALUES %s;
`, s.bqSuffix, tableName, strings.Join(rows, ",")))
require.NoError(s.t, err)
}

func (s PeerFlowE2ETestSuiteBQ) Test_Complete_QRep_Flow_Avro() {
env := e2e.NewTemporalTestWorkflowEnvironment(s.t)

@@ -46,6 +76,42 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Complete_QRep_Flow_Avro() {
e2e.RequireEqualTables(s, tblName, "*")
}

func (s PeerFlowE2ETestSuiteBQ) Test_Invalid_Timestamps_QRep() {
env := e2e.NewTemporalTestWorkflowEnvironment(s.t)

tblName := "test_qrep_flow_avro_bq"
s.setupTimeTable(tblName)

query := fmt.Sprintf("SELECT * FROM e2e_test_%s.%s WHERE watermark_ts BETWEEN {{.start}} AND {{.end}}",
s.bqSuffix, tblName)

qrepConfig, err := e2e.CreateQRepWorkflowConfig("test_qrep_flow_avro",
fmt.Sprintf("e2e_test_%s.%s", s.bqSuffix, tblName),
tblName,
query,
s.bqHelper.Peer,
"",
true,
"")
qrepConfig.WatermarkColumn = "watermark_ts"
require.NoError(s.t, err)
e2e.RunQrepFlowWorkflow(env, qrepConfig)

// Verify workflow completes without error
require.True(s.t, env.IsWorkflowCompleted())

err = env.GetWorkflowError()
require.NoError(s.t, err)

ok, err := s.bqHelper.CheckNull(tblName, []string{"mytimestamp"})
require.NoError(s.t, err)
require.False(s.t, ok)

ok, err = s.bqHelper.CheckNull(tblName, []string{"mytztimestamp"})
require.NoError(s.t, err)
require.False(s.t, ok)
}

func (s PeerFlowE2ETestSuiteBQ) Test_PeerDB_Columns_QRep_BQ() {
env := e2e.NewTemporalTestWorkflowEnvironment(s.t)

7 changes: 7 additions & 0 deletions flow/model/conversion_avro.go
@@ -4,6 +4,8 @@ import (
"encoding/json"
"fmt"

"go.temporal.io/sdk/log"

"github.com/PeerDB-io/peer-flow/model/qvalue"
)

@@ -12,19 +14,22 @@ type QRecordAvroConverter struct {
TargetDWH qvalue.QDWHType
NullableFields map[string]struct{}
ColNames []string
logger log.Logger
}

func NewQRecordAvroConverter(
q []qvalue.QValue,
targetDWH qvalue.QDWHType,
nullableFields map[string]struct{},
colNames []string,
logger log.Logger,
) *QRecordAvroConverter {
return &QRecordAvroConverter{
QRecord: q,
TargetDWH: targetDWH,
NullableFields: nullableFields,
ColNames: colNames,
logger: logger,
}
}

@@ -39,7 +44,9 @@ func (qac *QRecordAvroConverter) Convert() (map[string]interface{}, error) {
val,
qac.TargetDWH,
nullable,
qac.logger,
)

avroVal, err := avroConverter.ToAvroValue()
if err != nil {
return nil, fmt.Errorf("failed to convert QValue to Avro-compatible value: %w", err)
6 changes: 6 additions & 0 deletions flow/model/model.go
@@ -158,7 +158,13 @@ func (r *RecordItems) toMap(hstoreAsJSON bool) (map[string]interface{}, error) {
}

jsonStruct[col] = binStr
case qvalue.QValueKindQChar:
ch, ok := v.Value.(uint8)
if !ok {
return nil, fmt.Errorf("expected \"char\" value for column %s for %T", col, v.Value)
}

jsonStruct[col] = string(ch)
case qvalue.QValueKindString, qvalue.QValueKindJSON:
strVal, ok := v.Value.(string)
if !ok {
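
string(ch) on the stored uint8 produces a one-character string for that code point, which is how a "char" column value lands in the JSON map. A standalone illustration (not part of this commit) showing that byte values above 127 become their multi-byte UTF-8 encoding rather than a single raw byte:

package main

import "fmt"

func main() {
	var ch uint8 = 'A'
	fmt.Println(string(ch)) // "A"

	var high uint8 = 200
	s := string(high)
	fmt.Println(s, len(s)) // "È" 2: the UTF-8 encoding of U+00C8 is two bytes
}
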
8 changes: 8 additions & 0 deletions flow/model/qrecord_batch.go
@@ -137,6 +137,14 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
}
values[i] = v

case qvalue.QValueKindQChar:
v, ok := qValue.Value.(uint8)
if !ok {
src.err = fmt.Errorf("invalid \"char\" value")
return nil, src.err
}
values[i] = rune(v)

case qvalue.QValueKindString:
v, ok := qValue.Value.(string)
if !ok {