diff --git a/.github/workflows/flow.yml b/.github/workflows/flow.yml index cdffa40c5f..4dd58493c6 100644 --- a/.github/workflows/flow.yml +++ b/.github/workflows/flow.yml @@ -48,9 +48,17 @@ jobs: name: "bq_service_account.json" json: ${{ secrets.GCP_GH_CI_PKEY }} + - name: setup snowflake credentials + id: sf-credentials + uses: jsdaniell/create-json@v1.2.2 + with: + name: "snowflake_creds.json" + json: ${{ secrets.SNOWFLAKE_GH_CI_PKEY }} + - name: run tests run: | gotestsum --format testname working-directory: ./flow env: TEST_BQ_CREDS: ${{ github.workspace }}/bq_service_account.json + TEST_SF_CREDS: ${{ github.workspace }}/snowflake_creds.json diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index e848c78c8c..641d9f9455 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -65,7 +65,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( // Write each QRecord to the OCF file for _, qRecord := range records.Records { - avroMap, err := qRecord.ToAvroCompatibleMap(&nullable, records.ColumnNames) + avroMap, err := qRecord.ToAvroCompatibleMap(model.QDBTypeBigQuery, &nullable, records.Schema.GetColumnNames()) if err != nil { return 0, fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err) } @@ -205,7 +205,7 @@ func GetAvroType(bqField *bigquery.FieldSchema) (interface{}, error) { case bigquery.TimestampFieldType: return map[string]string{ "type": "long", - "logicalType": "timestamp-millis", + "logicalType": "timestamp-micros", }, nil case bigquery.DateFieldType: return map[string]string{ diff --git a/flow/connectors/bigquery/qrep_sync_method.go b/flow/connectors/bigquery/qrep_sync_method.go index e207c91b05..d063e4c2c6 100644 --- a/flow/connectors/bigquery/qrep_sync_method.go +++ b/flow/connectors/bigquery/qrep_sync_method.go @@ -72,7 +72,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( numRowsInserted := 0 for _, qRecord := range records.Records { toPut := QRecordValueSaver{ - ColumnNames: records.ColumnNames, + ColumnNames: records.Schema.GetColumnNames(), Record: qRecord, PartitionID: partitionID, RunID: runID, diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index 4df238f6b4..d714f389cd 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -36,6 +36,55 @@ func (qe *QRepQueryExecutor) ExecuteQuery(query string, args ...interface{}) (pg return rows, nil } +func fieldDescriptionToQValueKind(fd pgconn.FieldDescription) model.QValueKind { + switch fd.DataTypeOID { + case pgtype.BoolOID: + return model.QValueKindBoolean + case pgtype.Int2OID: + return model.QValueKindInt16 + case pgtype.Int4OID: + return model.QValueKindInt32 + case pgtype.Int8OID: + return model.QValueKindInt64 + case pgtype.Float4OID: + return model.QValueKindFloat32 + case pgtype.Float8OID: + return model.QValueKindFloat64 + case pgtype.TextOID, pgtype.VarcharOID: + return model.QValueKindString + case pgtype.ByteaOID: + return model.QValueKindBytes + case pgtype.JSONOID, pgtype.JSONBOID: + return model.QValueKindJSON + case pgtype.UUIDOID: + return model.QValueKindUUID + case pgtype.TimestampOID, pgtype.TimestamptzOID, pgtype.DateOID, pgtype.TimeOID: + return model.QValueKindETime + case pgtype.NumericOID: + return model.QValueKindNumeric + default: + return model.QValueKindInvalid + } +} + +// FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema. +func fieldDescriptionsToSchema(fds []pgconn.FieldDescription) *model.QRecordSchema { + qfields := make([]*model.QField, len(fds)) + for i, fd := range fds { + cname := fd.Name + ctype := fieldDescriptionToQValueKind(fd) + // there isn't a way to know if a column is nullable or not + // TODO fix this. + cnullable := true + qfields[i] = &model.QField{ + Name: cname, + Type: ctype, + Nullable: cnullable, + } + } + return model.NewQRecordSchema(qfields) +} + func (qe *QRepQueryExecutor) ProcessRows( rows pgx.Rows, fieldDescriptions []pgconn.FieldDescription, @@ -57,16 +106,10 @@ func (qe *QRepQueryExecutor) ProcessRows( return nil, fmt.Errorf("row iteration failed: %w", rows.Err()) } - // get col names from fieldDescriptions - colNames := make([]string, len(fieldDescriptions)) - for i, fd := range fieldDescriptions { - colNames[i] = fd.Name - } - batch := &model.QRecordBatch{ - NumRecords: uint32(len(records)), - Records: records, - ColumnNames: colNames, + NumRecords: uint32(len(records)), + Records: records, + Schema: fieldDescriptionsToSchema(fieldDescriptions), } return batch, nil diff --git a/flow/connectors/snowflake/client.go b/flow/connectors/snowflake/client.go new file mode 100644 index 0000000000..d07f0a3872 --- /dev/null +++ b/flow/connectors/snowflake/client.go @@ -0,0 +1,394 @@ +package connsnowflake + +import ( + "context" + "database/sql" + "fmt" + "math/big" + "strings" + "time" + + "github.com/jmoiron/sqlx" + log "github.com/sirupsen/logrus" + "github.com/snowflakedb/gosnowflake" + + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model" + util "github.com/PeerDB-io/peer-flow/utils" +) + +type SnowflakeClient struct { + // ctx is the context. + ctx context.Context + // config is the Snowflake config. + Config *protos.SnowflakeConfig + // connection to Snowflake + conn *sqlx.DB +} + +func NewSnowflakeClient(ctx context.Context, config *protos.SnowflakeConfig) (*SnowflakeClient, error) { + privateKey, err := util.DecodePKCS8PrivateKey([]byte(config.PrivateKey)) + if err != nil { + return nil, fmt.Errorf("failed to read private key: %w", err) + } + + snowflakeConfig := gosnowflake.Config{ + Account: config.AccountId, + User: config.Username, + Authenticator: gosnowflake.AuthTypeJwt, + PrivateKey: privateKey, + Database: config.Database, + Warehouse: config.Warehouse, + Role: config.Role, + RequestTimeout: time.Duration(config.QueryTimeout) * time.Second, + DisableTelemetry: true, + } + + snowflakeConfigDSN, err := gosnowflake.DSN(&snowflakeConfig) + if err != nil { + return nil, fmt.Errorf("failed to get DSN from Snowflake config: %w", err) + } + + database, err := sqlx.Open("snowflake", snowflakeConfigDSN) + if err != nil { + return nil, fmt.Errorf("failed to open connection to Snowflake peer: %w", err) + } + + err = database.PingContext(ctx) + if err != nil { + return nil, fmt.Errorf("failed to open connection to Snowflake peer: %w", err) + } + + return &SnowflakeClient{ + ctx: ctx, + Config: config, + conn: database, + }, nil +} + +// ConnectionActive checks if the connection is active. +func (s *SnowflakeClient) ConnectionActive() bool { + return s.conn.PingContext(s.ctx) == nil +} + +// add a Close() method to SnowflakeClient +func (s *SnowflakeClient) Close() error { + return s.conn.Close() +} + +// schemaExists checks if the schema exists. +func (s *SnowflakeClient) schemaExists(schema string) (bool, error) { + var exists bool + query := fmt.Sprintf("SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s'", schema) + err := s.conn.QueryRowContext(s.ctx, query).Scan(&exists) + if err != nil { + return false, fmt.Errorf("failed to query schema: %w", err) + } + + return exists, nil +} + +// RecreateSchema recreates the schema, i.e., drops it if exists and creates it again. +func (s *SnowflakeClient) RecreateSchema(schema string) error { + exists, err := s.schemaExists(schema) + if err != nil { + return fmt.Errorf("failed to check if schema %s exists: %w", schema, err) + } + + if exists { + stmt := fmt.Sprintf("DROP SCHEMA %s", schema) + _, err := s.conn.ExecContext(s.ctx, stmt) + if err != nil { + return fmt.Errorf("failed to drop schema: %w", err) + } + } + + stmt := fmt.Sprintf("CREATE SCHEMA %s", schema) + _, err = s.conn.ExecContext(s.ctx, stmt) + if err != nil { + return fmt.Errorf("failed to create schema: %w", err) + } + + fmt.Printf("created schema %s successfully\n", schema) + return nil +} + +// DropSchema drops the schema. +func (s *SnowflakeClient) DropSchema(schema string) error { + exists, err := s.schemaExists(schema) + if err != nil { + return fmt.Errorf("failed to check if schema %s exists: %w", schema, err) + } + + if exists { + stmt := fmt.Sprintf("DROP SCHEMA %s", schema) + _, err := s.conn.ExecContext(s.ctx, stmt) + if err != nil { + return fmt.Errorf("failed to drop schema: %w", err) + } + } + + return nil +} + +// RunCommand runs the given command. +func (s *SnowflakeClient) RunCommand(command string) error { + _, err := s.conn.ExecContext(s.ctx, command) + if err != nil { + return fmt.Errorf("failed to run command: %w", err) + } + + return nil +} + +// CountRows returns the number of rows in the given table. +func (s *SnowflakeClient) CountRows(schema string, tableName string) (int, error) { + var count int + query := fmt.Sprintf("SELECT COUNT(*) FROM %s.%s", schema, tableName) + err := s.conn.GetContext(s.ctx, &count, query) + if err != nil { + return 0, fmt.Errorf("failed to run command: %w", err) + } + + return count, nil +} + +func toQValue(kind model.QValueKind, val interface{}) (model.QValue, error) { + switch kind { + case model.QValueKindInt32: + if v, ok := val.(*int); ok && v != nil { + return model.QValue{Kind: model.QValueKindInt32, Value: *v}, nil + } + case model.QValueKindInt64: + if v, ok := val.(*int64); ok && v != nil { + return model.QValue{Kind: model.QValueKindInt64, Value: *v}, nil + } + case model.QValueKindFloat32: + if v, ok := val.(*float32); ok && v != nil { + return model.QValue{Kind: model.QValueKindFloat32, Value: *v}, nil + } + case model.QValueKindFloat64: + if v, ok := val.(*float64); ok && v != nil { + return model.QValue{Kind: model.QValueKindFloat64, Value: *v}, nil + } + case model.QValueKindString: + if v, ok := val.(*string); ok && v != nil { + return model.QValue{Kind: model.QValueKindString, Value: *v}, nil + } + case model.QValueKindBoolean: + if v, ok := val.(*bool); ok && v != nil { + return model.QValue{Kind: model.QValueKindBoolean, Value: *v}, nil + } + case model.QValueKindNumeric: + // convert string to big.Rat + if v, ok := val.(*string); ok && v != nil { + //nolint:gosec + ratVal, ok := new(big.Rat).SetString(*v) + if !ok { + return model.QValue{}, fmt.Errorf("failed to convert string to big.Rat: %s", *v) + } + return model.QValue{ + Kind: model.QValueKindNumeric, + Value: ratVal, + }, nil + } + case model.QValueKindETime: + if v, ok := val.(*time.Time); ok && v != nil { + etimeVal, err := model.NewExtendedTime(*v, model.DateTimeKindType, "") + if err != nil { + return model.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err) + } + return model.QValue{ + Kind: model.QValueKindETime, + Value: etimeVal, + }, nil + } + case model.QValueKindBytes: + if v, ok := val.(*[]byte); ok && v != nil { + return model.QValue{Kind: model.QValueKindBytes, Value: *v}, nil + } + } + + // If type is unsupported or doesn't match the specified kind, return error + return model.QValue{}, fmt.Errorf("[snowflakeclient] unsupported type %T for kind %s", val, kind) +} + +// databaseTypeNameToQValueKind converts a database type name to a QValueKind. +func databaseTypeNameToQValueKind(name string) (model.QValueKind, error) { + switch name { + case "INT": + return model.QValueKindInt32, nil + case "BIGINT": + return model.QValueKindInt64, nil + case "FLOAT": + return model.QValueKindFloat32, nil + case "DOUBLE", "REAL": + return model.QValueKindFloat64, nil + case "VARCHAR", "CHAR", "TEXT": + return model.QValueKindString, nil + case "BOOLEAN": + return model.QValueKindBoolean, nil + case "DATETIME", "TIMESTAMP", "TIMESTAMP_LTZ", "TIMESTAMP_NTZ", "TIMESTAMP_TZ": + return model.QValueKindETime, nil + case "BLOB", "BYTEA", "BINARY": + return model.QValueKindBytes, nil + case "FIXED", "NUMBER": + return model.QValueKindNumeric, nil + default: + // If type is unsupported, return an error + return "", fmt.Errorf("unsupported database type name: %s", name) + } +} + +func columnTypeToQField(ct *sql.ColumnType) (*model.QField, error) { + qvKind, err := databaseTypeNameToQValueKind(ct.DatabaseTypeName()) + if err != nil { + return nil, err + } + + nullable, ok := ct.Nullable() + + return &model.QField{ + Name: ct.Name(), + Type: qvKind, + Nullable: ok && nullable, + }, nil +} + +func (s *SnowflakeClient) ExecuteAndProcessQuery(query string) (*model.QRecordBatch, error) { + rows, err := s.conn.QueryContext(s.ctx, query) + if err != nil { + fmt.Printf("failed to run command: %v\n", err) + return nil, fmt.Errorf("failed to run command: %w", err) + } + defer rows.Close() + + dbColTypes, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + // Convert dbColTypes to QFields + qfields := make([]*model.QField, len(dbColTypes)) + for i, ct := range dbColTypes { + qfield, err := columnTypeToQField(ct) + if err != nil { + log.Errorf("failed to convert column type %v: %v", ct, err) + return nil, err + } + qfields[i] = qfield + } + + var records []*model.QRecord + + for rows.Next() { + columns, err := rows.Columns() + if err != nil { + return nil, err + } + + values := make([]interface{}, len(columns)) + for i := range values { + switch qfields[i].Type { + case model.QValueKindETime: + values[i] = new(time.Time) + case model.QValueKindInt16: + values[i] = new(int16) + case model.QValueKindInt32: + values[i] = new(int32) + case model.QValueKindInt64: + values[i] = new(int64) + case model.QValueKindFloat32: + values[i] = new(float32) + case model.QValueKindFloat64: + values[i] = new(float64) + case model.QValueKindBoolean: + values[i] = new(bool) + case model.QValueKindString: + values[i] = new(string) + case model.QValueKindBytes: + values[i] = new([]byte) + case model.QValueKindNumeric: + values[i] = new(string) + default: + values[i] = new(interface{}) + } + } + + if err := rows.Scan(values...); err != nil { + return nil, err + } + + qValues := make([]model.QValue, len(values)) + for i, val := range values { + qv, err := toQValue(qfields[i].Type, val) + if err != nil { + log.Errorf("failed to convert value: %v", err) + return nil, err + } + qValues[i] = qv + } + + // Create a QRecord + record := model.NewQRecord(len(qValues)) + for i, qv := range qValues { + record.Set(i, qv) + } + + records = append(records, record) + } + + if err := rows.Err(); err != nil { + log.Errorf("failed to iterate over rows: %v", err) + return nil, err + } + + // Return a QRecordBatch + return &model.QRecordBatch{ + NumRecords: uint32(len(records)), + Records: records, + Schema: model.NewQRecordSchema(qfields), + }, nil +} + +func (s *SnowflakeClient) CreateTable(schema *model.QRecordSchema, schemaName string, tableName string) error { + var fields []string + for _, field := range schema.Fields { + snowflakeType, err := qValueKindToSnowflakeColTypeString(field.Type) + if err != nil { + return err + } + fields = append(fields, fmt.Sprintf("%s %s", field.Name, snowflakeType)) + } + + command := fmt.Sprintf("CREATE TABLE %s.%s (%s)", schemaName, tableName, strings.Join(fields, ", ")) + fmt.Printf("creating table %s.%s with command %s\n", schemaName, tableName, command) + + _, err := s.conn.ExecContext(s.ctx, command) + if err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + + return nil +} + +func qValueKindToSnowflakeColTypeString(val model.QValueKind) (string, error) { + switch val { + case model.QValueKindInt32, model.QValueKindInt64: + return "INT", nil + case model.QValueKindFloat32, model.QValueKindFloat64: + return "FLOAT", nil + case model.QValueKindString: + return "STRING", nil + case model.QValueKindBoolean: + return "BOOLEAN", nil + case model.QValueKindETime: + return "TIMESTAMP_LTZ", nil + case model.QValueKindBytes: + return "BINARY", nil + case model.QValueKindNumeric: + return "NUMERIC(38,32)", nil + default: + return "", fmt.Errorf("unsupported QValueKind: %v", val) + } +} diff --git a/flow/connectors/snowflake/qrep.go b/flow/connectors/snowflake/qrep.go index 367c301025..c4fb25b058 100644 --- a/flow/connectors/snowflake/qrep.go +++ b/flow/connectors/snowflake/qrep.go @@ -1,10 +1,19 @@ package connsnowflake import ( + "database/sql" + "fmt" + "os" + "time" + "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + log "github.com/sirupsen/logrus" + "google.golang.org/protobuf/encoding/protojson" ) +const qRepMetadataTableName = "_peerdb_query_replication_metadata" + func (c *SnowflakeConnector) GetQRepPartitions(config *protos.QRepConfig, last *protos.QRepPartition, ) ([]*protos.QRepPartition, error) { @@ -22,9 +31,124 @@ func (c *SnowflakeConnector) SyncQRepRecords( partition *protos.QRepPartition, records *model.QRecordBatch, ) (int, error) { - panic("not implemented") + // Ensure the destination table is available. + destTable := config.DestinationTableIdentifier + + tblSchema, err := c.getTableSchema(destTable) + if err != nil { + return 0, fmt.Errorf("failed to get schema of table %s: %w", destTable, err) + } + + done, err := c.isPartitionSynced(partition.PartitionId) + if err != nil { + return 0, fmt.Errorf("failed to check if partition %s is synced: %w", partition.PartitionId, err) + } + + if done { + log.Infof("Partition %s has already been synced", partition.PartitionId) + return 0, nil + } + + syncMode := config.SyncMode + switch syncMode { + case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT: + return 0, fmt.Errorf("multi-insert sync mode not supported for snowflake") + case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO: + // create a temp directory for storing avro files + tmpDir, err := os.MkdirTemp("", "peerdb-avro") + if err != nil { + return 0, fmt.Errorf("failed to create temp directory: %w", err) + } + avroSync := &SnowflakeAvroSyncMethod{connector: c, localDir: tmpDir} + return avroSync.SyncQRepRecords(config.FlowJobName, destTable, partition, tblSchema, records) + default: + return 0, fmt.Errorf("unsupported sync mode: %s", syncMode) + } +} + +func (c *SnowflakeConnector) createMetadataInsertStatement( + partition *protos.QRepPartition, + jobName string, + startTime time.Time, +) (string, error) { + // marshal the partition to json using protojson + pbytes, err := protojson.Marshal(partition) + if err != nil { + return "", fmt.Errorf("failed to marshal partition to json: %v", err) + } + + // convert the bytes to string + partitionJSON := string(pbytes) + + insertMetadataStmt := fmt.Sprintf( + `INSERT INTO %s.%s + (flowJobName, partitionID, syncPartition, syncStartTime, syncFinishTime) + VALUES ('%s', '%s', '%s', '%s'::timestamp, CURRENT_TIMESTAMP);`, + "public", qRepMetadataTableName, jobName, partition.PartitionId, + partitionJSON, startTime.Format(time.RFC3339)) + + return insertMetadataStmt, nil +} + +func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType, error) { + //nolint:gosec + queryString := fmt.Sprintf(` + SELECT * + FROM %s + LIMIT 0 + `, tableName) + + rows, err := c.database.Query(queryString) + if err != nil { + return nil, fmt.Errorf("failed to execute query: %w", err) + } + defer rows.Close() + + columnTypes, err := rows.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("failed to get column types: %w", err) + } + + return columnTypes, nil +} + +func (c *SnowflakeConnector) isPartitionSynced(partitionID string) (bool, error) { + //nolint:gosec + queryString := fmt.Sprintf(` + SELECT COUNT(*) + FROM _peerdb_query_replication_metadata + WHERE partitionID = '%s' + `, partitionID) + + row := c.database.QueryRow(queryString) + + var count int + if err := row.Scan(&count); err != nil { + return false, fmt.Errorf("failed to execute query: %w", err) + } + + return count > 0, nil } func (c *SnowflakeConnector) SetupQRepMetadataTables(config *protos.QRepConfig) error { - panic("SetupQRepMetadataTables not implemented for snowflake connector") + // Define the schema + schemaStatement := ` + CREATE TABLE IF NOT EXISTS %s.%s ( + flowJobName STRING, + partitionID STRING, + syncPartition STRING, + syncStartTime TIMESTAMP_LTZ, + syncFinishTime TIMESTAMP_LTZ + ); + ` + queryString := fmt.Sprintf(schemaStatement, "public", qRepMetadataTableName) + + // Execute the query + _, err := c.database.Exec(queryString) + if err != nil { + return fmt.Errorf("failed to create table %s.%s: %w", "public", qRepMetadataTableName, err) + } + + log.Infof("Created table %s", qRepMetadataTableName) + return nil } diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go new file mode 100644 index 0000000000..cb3fdb695b --- /dev/null +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -0,0 +1,127 @@ +package connsnowflake + +import ( + "database/sql" + "fmt" + "os" + "strings" + "time" + + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model" + util "github.com/PeerDB-io/peer-flow/utils" + "github.com/linkedin/goavro/v2" + log "github.com/sirupsen/logrus" + _ "github.com/snowflakedb/gosnowflake" +) + +type SnowflakeAvroSyncMethod struct { + connector *SnowflakeConnector + localDir string +} + +func NewSnowflakeAvroSyncMethod(connector *SnowflakeConnector, localDir string) *SnowflakeAvroSyncMethod { + return &SnowflakeAvroSyncMethod{ + connector: connector, + localDir: localDir, + } +} + +func (s *SnowflakeAvroSyncMethod) SyncQRepRecords( + flowJobName string, + dstTableName string, + partition *protos.QRepPartition, + dstTableSchema []*sql.ColumnType, + records *model.QRecordBatch) (int, error) { + startTime := time.Now() + + // You will need to define your Avro schema as a string + avroSchema, err := model.GetAvroSchemaDefinition(dstTableName, records.Schema) + if err != nil { + return 0, fmt.Errorf("failed to define Avro schema: %w", err) + } + + fmt.Printf("Avro schema: %v\n", avroSchema) + + // Create a local file path with flowJobName and partitionID + localFilePath := fmt.Sprintf("%s/%s.avro", s.localDir, partition.PartitionId) + file, err := os.Create(localFilePath) + if err != nil { + return 0, fmt.Errorf("failed to create file: %w", err) + } + defer file.Close() + + // Create OCF Writer + ocfWriter, err := goavro.NewOCFWriter(goavro.OCFConfig{ + W: file, + Schema: avroSchema.Schema, + }) + if err != nil { + return 0, fmt.Errorf("failed to create OCF writer: %w", err) + } + + colNames := records.Schema.GetColumnNames() + + // Write each QRecord to the OCF file + for _, qRecord := range records.Records { + avroMap, err := qRecord.ToAvroCompatibleMap(model.QDBTypeSnowflake, &avroSchema.NullableFields, colNames) + if err != nil { + log.Errorf("failed to convert QRecord to Avro compatible map: %v", err) + return 0, fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err) + } + + err = ocfWriter.Append([]interface{}{avroMap}) + if err != nil { + log.Errorf("failed to write record to OCF file: %v", err) + return 0, fmt.Errorf("failed to write record to OCF file: %w", err) + } + } + + // this runID is just used for the staging table name + runID, err := util.RandomUInt64() + if err != nil { + return 0, fmt.Errorf("failed to generate run ID: %w", err) + } + + // create temp stag + stage := fmt.Sprintf("%s_%d", dstTableName, runID) + createStageCmd := fmt.Sprintf("CREATE TEMPORARY STAGE %s FILE_FORMAT = (TYPE = AVRO)", stage) + if _, err = s.connector.database.Exec(createStageCmd); err != nil { + return 0, fmt.Errorf("failed to create temp stage: %w", err) + } + log.Infof("created temp stage %s", stage) + + // Put the local Avro file to the Snowflake stage + putCmd := fmt.Sprintf("PUT file://%s @%s", localFilePath, stage) + if _, err = s.connector.database.Exec(putCmd); err != nil { + return 0, fmt.Errorf("failed to put file to stage: %w", err) + } + log.Infof("put file %s to stage %s", localFilePath, stage) + + // write this file to snowflake using COPY INTO statement + copyOpts := []string{ + "FILE_FORMAT = (TYPE = AVRO)", + "MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'", + } + //nolint:gosec + copyCmd := fmt.Sprintf("COPY INTO %s FROM @%s %s", dstTableName, stage, strings.Join(copyOpts, ",")) + if _, err = s.connector.database.Exec(copyCmd); err != nil { + return 0, fmt.Errorf("failed to run COPY INTO command: %w", err) + } + log.Infof("copied file from stage %s to table %s", stage, dstTableName) + + // Insert metadata statement + insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime) + if err != nil { + return -1, fmt.Errorf("failed to create metadata insert statement: %v", err) + } + + // Execute the metadata insert statement + if _, err = s.connector.database.Exec(insertMetadataStmt); err != nil { + return -1, fmt.Errorf("failed to execute metadata insert statement: %v", err) + } + + log.Infof("pushed %d records to local file %s and loaded into Snowflake table %s", + len(records.Records), localFilePath, dstTableName) + return len(records.Records), nil +} diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 02ab58af8d..d38601883a 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -2,11 +2,8 @@ package connsnowflake import ( "context" - "crypto/rsa" - "crypto/x509" "database/sql" "encoding/json" - "encoding/pem" "fmt" "regexp" "strings" @@ -14,6 +11,7 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + util "github.com/PeerDB-io/peer-flow/utils" "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/snowflakedb/gosnowflake" @@ -70,10 +68,12 @@ type tableNameComponents struct { schemaIdentifier string tableIdentifier string } + type SnowflakeConnector struct { ctx context.Context database *sql.DB tableSchemaMapping map[string]*protos.TableSchema + client *SnowflakeClient } type snowflakeRawRecord struct { @@ -87,28 +87,9 @@ type snowflakeRawRecord struct { items map[string]interface{} } -// reads the PKCS8 private key from the received config and converts it into something that gosnowflake wants. -func readPKCS8PrivateKey(rawKey []byte) (*rsa.PrivateKey, error) { - // pem.Decode has weird return values, no err as such - PEMBlock, _ := pem.Decode(rawKey) - if PEMBlock == nil { - return nil, fmt.Errorf("failed to decode private key PEM block") - } - privateKeyAny, err := x509.ParsePKCS8PrivateKey(PEMBlock.Bytes) - if err != nil { - return nil, fmt.Errorf("failed to parse private key PEM block as PKCS8: %w", err) - } - privateKeyRSA, ok := privateKeyAny.(*rsa.PrivateKey) - if !ok { - return nil, fmt.Errorf("key does not appear to RSA as expected") - } - - return privateKeyRSA, nil -} - func NewSnowflakeConnector(ctx context.Context, snowflakeProtoConfig *protos.SnowflakeConfig) (*SnowflakeConnector, error) { - PrivateKeyRSA, err := readPKCS8PrivateKey([]byte(snowflakeProtoConfig.PrivateKey)) + PrivateKeyRSA, err := util.DecodePKCS8PrivateKey([]byte(snowflakeProtoConfig.PrivateKey)) if err != nil { return nil, err } @@ -139,10 +120,16 @@ func NewSnowflakeConnector(ctx context.Context, return nil, fmt.Errorf("failed to open connection to Snowflake peer: %w", err) } + client, err := NewSnowflakeClient(ctx, snowflakeProtoConfig) + if err != nil { + return nil, fmt.Errorf("failed to create Snowflake client: %w", err) + } + return &SnowflakeConnector{ ctx: ctx, database: database, tableSchemaMapping: nil, + client: client, }, nil } diff --git a/flow/e2e/bigquery_helper.go b/flow/e2e/bigquery_helper.go index 4f023958f5..8c3d09e57d 100644 --- a/flow/e2e/bigquery_helper.go +++ b/flow/e2e/bigquery_helper.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "math/big" - "math/rand" "os" "strings" "time" @@ -14,14 +13,17 @@ import ( peer_bq "github.com/PeerDB-io/peer-flow/connectors/bigquery" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + util "github.com/PeerDB-io/peer-flow/utils" "google.golang.org/api/iterator" ) type BigQueryTestHelper struct { // runID uniquely identifies the test run to namespace stateful schemas. - runID int64 + runID uint64 // config is the BigQuery config. Config *protos.BigqueryConfig + // peer struct holder BigQuery + Peer *protos.Peer // client to talk to BigQuery client *bigquery.Client // dataset to use for testing. @@ -31,7 +33,10 @@ type BigQueryTestHelper struct { // NewBigQueryTestHelper creates a new BigQueryTestHelper. func NewBigQueryTestHelper() (*BigQueryTestHelper, error) { // random 64 bit int to namespace stateful schemas. - runID := rand.Int63() + runID, err := util.RandomUInt64() + if err != nil { + return nil, fmt.Errorf("failed to generate random uint64: %w", err) + } jsonPath := os.Getenv("TEST_BQ_CREDS") if jsonPath == "" { @@ -62,14 +67,29 @@ func NewBigQueryTestHelper() (*BigQueryTestHelper, error) { return nil, fmt.Errorf("failed to create helper BigQuery client: %v", err) } + peer := generateBQPeer(&config) + return &BigQueryTestHelper{ runID: runID, Config: &config, client: client, datasetName: config.DatasetId, + Peer: peer, }, nil } +func generateBQPeer(bigQueryConfig *protos.BigqueryConfig) *protos.Peer { + ret := &protos.Peer{} + ret.Name = "test_bq_peer" + ret.Type = protos.DBType_BIGQUERY + + ret.Config = &protos.Peer_BigqueryConfig{ + BigqueryConfig: bigQueryConfig, + } + + return ret +} + // datasetExists checks if the dataset exists. func (b *BigQueryTestHelper) datasetExists() (bool, error) { dataset := b.client.Dataset(b.Config.DatasetId) @@ -207,6 +227,65 @@ func toQValue(bqValue bigquery.Value) (model.QValue, error) { } } +// bqFieldTypeToQValueKind converts a bigquery FieldType to a QValueKind. +func bqFieldTypeToQValueKind(fieldType bigquery.FieldType) (model.QValueKind, error) { + switch fieldType { + case bigquery.StringFieldType: + return model.QValueKindString, nil + case bigquery.BytesFieldType: + return model.QValueKindBytes, nil + case bigquery.IntegerFieldType: + return model.QValueKindInt64, nil + case bigquery.FloatFieldType: + return model.QValueKindFloat64, nil + case bigquery.BooleanFieldType: + return model.QValueKindBoolean, nil + case bigquery.TimestampFieldType: + return model.QValueKindETime, nil + case bigquery.RecordFieldType: + return model.QValueKindStruct, nil + case bigquery.DateFieldType: + return model.QValueKindETime, nil + case bigquery.TimeFieldType: + return model.QValueKindETime, nil + case bigquery.NumericFieldType: + return model.QValueKindNumeric, nil + case bigquery.GeographyFieldType: + return model.QValueKindString, nil + default: + return "", fmt.Errorf("unsupported bigquery field type: %v", fieldType) + } +} + +func bqFieldSchemaToQField(fieldSchema *bigquery.FieldSchema) (*model.QField, error) { + qValueKind, err := bqFieldTypeToQValueKind(fieldSchema.Type) + if err != nil { + return nil, err + } + + return &model.QField{ + Name: fieldSchema.Name, + Type: qValueKind, + Nullable: !fieldSchema.Required, + }, nil +} + +// bqSchemaToQRecordSchema converts a bigquery schema to a QRecordSchema. +func bqSchemaToQRecordSchema(schema bigquery.Schema) (*model.QRecordSchema, error) { + var fields []*model.QField + for _, fieldSchema := range schema { + qField, err := bqFieldSchemaToQField(fieldSchema) + if err != nil { + return nil, err + } + fields = append(fields, qField) + } + + return &model.QRecordSchema{ + Fields: fields, + }, nil +} + func (b *BigQueryTestHelper) ExecuteAndProcessQuery(query string) (*model.QRecordBatch, error) { it, err := b.client.Query(query).Read(context.Background()) if err != nil { @@ -247,18 +326,64 @@ func (b *BigQueryTestHelper) ExecuteAndProcessQuery(query string) (*model.QRecor // Now you should fill the column names as well. Here we assume the schema is // retrieved from the query itself - var columnNames []string + var schema *model.QRecordSchema if it.Schema != nil { - columnNames = make([]string, len(it.Schema)) - for i, fieldSchema := range it.Schema { - columnNames[i] = fieldSchema.Name + schema, err = bqSchemaToQRecordSchema(it.Schema) + if err != nil { + return nil, err } } // Return a QRecordBatch return &model.QRecordBatch{ - NumRecords: uint32(len(records)), - Records: records, - ColumnNames: columnNames, + NumRecords: uint32(len(records)), + Records: records, + Schema: schema, }, nil } + +func qValueKindToBqColTypeString(val model.QValueKind) (string, error) { + switch val { + case model.QValueKindInt32: + return "INT64", nil + case model.QValueKindInt64: + return "INT64", nil + case model.QValueKindFloat32: + return "FLOAT64", nil + case model.QValueKindFloat64: + return "FLOAT64", nil + case model.QValueKindString: + return "STRING", nil + case model.QValueKindBoolean: + return "BOOL", nil + case model.QValueKindETime: + return "TIMESTAMP", nil + case model.QValueKindBytes: + return "BYTES", nil + case model.QValueKindNumeric: + return "NUMERIC", nil + default: + return "", fmt.Errorf("unsupported QValueKind: %v", val) + } +} + +func (b *BigQueryTestHelper) CreateTable(tableName string, schema *model.QRecordSchema) error { + var fields []string + for _, field := range schema.Fields { + bqType, err := qValueKindToBqColTypeString(field.Type) + if err != nil { + return err + } + fields = append(fields, fmt.Sprintf("%s %s", field.Name, bqType)) + } + + command := fmt.Sprintf("CREATE TABLE %s.%s (%s)", b.datasetName, tableName, strings.Join(fields, ", ")) + fmt.Printf("creating table %s with command %s\n", tableName, command) + + err := b.RunCommand(command) + if err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + + return nil +} diff --git a/flow/e2e/congen.go b/flow/e2e/congen.go index 813f93dcf3..db02656edf 100644 --- a/flow/e2e/congen.go +++ b/flow/e2e/congen.go @@ -1,28 +1,9 @@ package e2e import ( - "fmt" - "io" - "os" - "github.com/PeerDB-io/peer-flow/generated/protos" ) -type FlowConnectionGenerationConfig struct { - FlowJobName string - TableNameMapping map[string]string - PostgresPort int - BigQueryConfig *protos.BigqueryConfig -} - -type QRepFlowConnectionGenerationConfig struct { - FlowJobName string - SourceTableIdentifier string - DestinationTableIdentifier string - PostgresPort int - BigQueryConfig *protos.BigqueryConfig -} - // GeneratePostgresPeer generates a postgres peer config for testing. func GeneratePostgresPeer(postgresPort int) *protos.Peer { ret := &protos.Peer{} @@ -42,33 +23,21 @@ func GeneratePostgresPeer(postgresPort int) *protos.Peer { return ret } -// readFileToBytes reads a file to a byte array. -func readFileToBytes(path string) ([]byte, error) { - var ret []byte - - f, err := os.Open(path) - if err != nil { - return ret, fmt.Errorf("failed to open file: %w", err) - } - - defer f.Close() - - ret, err = io.ReadAll(f) - if err != nil { - return ret, fmt.Errorf("failed to read file: %w", err) - } - - return ret, nil +type FlowConnectionGenerationConfig struct { + FlowJobName string + TableNameMapping map[string]string + PostgresPort int + Destination *protos.Peer } -// GenerateBQPeer generates a bigquery peer config for testing. -func GenerateBQPeer(bigQueryConfig *protos.BigqueryConfig) (*protos.Peer, error) { +// GenerateSnowflakePeer generates a snowflake peer config for testing. +func GenerateSnowflakePeer(snowflakeConfig *protos.SnowflakeConfig) (*protos.Peer, error) { ret := &protos.Peer{} - ret.Name = "test_bq_peer" - ret.Type = protos.DBType_BIGQUERY + ret.Name = "test_snowflake_peer" + ret.Type = protos.DBType_SNOWFLAKE - ret.Config = &protos.Peer_BigqueryConfig{ - BigqueryConfig: bigQueryConfig, + ret.Config = &protos.Peer_SnowflakeConfig{ + SnowflakeConfig: snowflakeConfig, } return ret, nil @@ -76,39 +45,33 @@ func GenerateBQPeer(bigQueryConfig *protos.BigqueryConfig) (*protos.Peer, error) func (c *FlowConnectionGenerationConfig) GenerateFlowConnectionConfigs() (*protos.FlowConnectionConfigs, error) { ret := &protos.FlowConnectionConfigs{} - ret.FlowJobName = c.FlowJobName ret.TableNameMapping = c.TableNameMapping - ret.Source = GeneratePostgresPeer(c.PostgresPort) - - bqPeer, err := GenerateBQPeer(c.BigQueryConfig) - if err != nil { - return nil, fmt.Errorf("failed to generate bq peer: %w", err) - } - - ret.Destination = bqPeer - + ret.Destination = c.Destination return ret, nil } +type QRepFlowConnectionGenerationConfig struct { + FlowJobName string + SourceTableIdentifier string + DestinationTableIdentifier string + PostgresPort int + Destination *protos.Peer +} + // GenerateQRepConfig generates a qrep config for testing. func (c *QRepFlowConnectionGenerationConfig) GenerateQRepConfig( query string, watermark string, syncMode protos.QRepSyncMode) (*protos.QRepConfig, error) { ret := &protos.QRepConfig{} - ret.FlowJobName = c.FlowJobName ret.SourceTableIdentifier = c.SourceTableIdentifier ret.DestinationTableIdentifier = c.DestinationTableIdentifier postgresPeer := GeneratePostgresPeer(c.PostgresPort) - bqPeer, err := GenerateBQPeer(c.BigQueryConfig) - if err != nil { - return nil, fmt.Errorf("failed to generate bq peer: %w", err) - } - ret.SourcePeer = postgresPeer - ret.DestinationPeer = bqPeer + + ret.DestinationPeer = c.Destination ret.Query = query ret.WatermarkColumn = watermark diff --git a/flow/e2e/peer_flow_test.go b/flow/e2e/peer_flow_test.go index ff2fc5b290..6b727cf918 100644 --- a/flow/e2e/peer_flow_test.go +++ b/flow/e2e/peer_flow_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/PeerDB-io/peer-flow/activities" + util "github.com/PeerDB-io/peer-flow/utils" peerflow "github.com/PeerDB-io/peer-flow/workflows" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/mock" @@ -21,6 +22,7 @@ type E2EPeerFlowTestSuite struct { pool *pgxpool.Pool bqHelper *BigQueryTestHelper + sfHelper *SnowflakeTestHelper } func TestE2EPeerFlowTestSuite(t *testing.T) { @@ -92,10 +94,38 @@ func (s *E2EPeerFlowTestSuite) setupBigQuery() error { return fmt.Errorf("failed to create bigquery helper: %w", err) } + err = bqHelper.RecreateDataset() + if err != nil { + return fmt.Errorf("failed to recreate bigquery dataset: %w", err) + } + s.bqHelper = bqHelper return nil } +// setupSnowflake sets up the snowflake connection. +func (s *E2EPeerFlowTestSuite) setupSnowflake() error { + runID, err := util.RandomUInt64() + if err != nil { + return fmt.Errorf("failed to generate random uint64: %w", err) + } + + testSchemaName := fmt.Sprintf("e2e_test_%d", runID) + + sfHelper, err := NewSnowflakeTestHelper(testSchemaName) + if err != nil { + return fmt.Errorf("failed to create snowflake helper: %w", err) + } + + err = sfHelper.RecreateSchema() + if err != nil { + return fmt.Errorf("failed to recreate snowflake schema: %w", err) + } + + s.sfHelper = sfHelper + return nil +} + // Implement SetupAllSuite interface to setup the test suite func (s *E2EPeerFlowTestSuite) SetupSuite() { // seed the random number generator with current time @@ -112,9 +142,9 @@ func (s *E2EPeerFlowTestSuite) SetupSuite() { s.Fail("failed to setup bigquery", err) } - err = s.bqHelper.RecreateDataset() + err = s.setupSnowflake() if err != nil { - s.Fail("failed to recreate bigquery dataset", err) + s.Fail("failed to setup snowflake", err) } } @@ -134,6 +164,13 @@ func (s *E2EPeerFlowTestSuite) TearDownSuite() { if err != nil { s.Fail("failed to drop bigquery dataset", err) } + + if s.sfHelper != nil { + err = s.sfHelper.DropSchema() + if err != nil { + s.Fail("failed to drop snowflake schema", err) + } + } } func registerWorkflowsAndActivities(env *testsuite.TestWorkflowEnvironment) { @@ -194,7 +231,7 @@ func (s *E2EPeerFlowTestSuite) Test_Complete_Flow_No_Data() { FlowJobName: "test_complete_flow_no_data", TableNameMapping: map[string]string{"e2e_test.test": "test"}, PostgresPort: postgresPort, - BigQueryConfig: s.bqHelper.Config, + Destination: s.bqHelper.Peer, } flowConnConfig, err := connectionGen.GenerateFlowConnectionConfigs() @@ -238,7 +275,7 @@ func (s *E2EPeerFlowTestSuite) Test_Char_ColType_Error() { FlowJobName: "test_char_table", TableNameMapping: map[string]string{"e2e_test.test_char_table": "test"}, PostgresPort: postgresPort, - BigQueryConfig: s.bqHelper.Config, + Destination: s.bqHelper.Peer, } flowConnConfig, err := connectionGen.GenerateFlowConnectionConfigs() @@ -285,7 +322,7 @@ func (s *E2EPeerFlowTestSuite) Test_Complete_Simple_Flow() { FlowJobName: "test_complete_single_col_flow", TableNameMapping: map[string]string{"e2e_test.test_simple_flow": "test_simple_flow"}, PostgresPort: postgresPort, - BigQueryConfig: s.bqHelper.Config, + Destination: s.bqHelper.Peer, } flowConnConfig, err := connectionGen.GenerateFlowConnectionConfigs() diff --git a/flow/e2e/qrep_flow_test.go b/flow/e2e/qrep_flow_test.go index 16ecbe86b3..6ca5e15f8d 100644 --- a/flow/e2e/qrep_flow_test.go +++ b/flow/e2e/qrep_flow_test.go @@ -7,6 +7,7 @@ import ( connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model" peerflow "github.com/PeerDB-io/peer-flow/workflows" "github.com/google/uuid" ) @@ -79,65 +80,86 @@ func (s *E2EPeerFlowTestSuite) setupSourceTable(tableName string, rowCount int) } } -func (s *E2EPeerFlowTestSuite) setupDestinationTable(dstTable string) { - dstTableName := fmt.Sprintf("%s.%s", s.bqHelper.Config.DatasetId, dstTable) - colWithTypes := []string{ - "id STRING", - "card_id STRING", - "from_v TIMESTAMP", - "price NUMERIC", - "created_at TIMESTAMP", - "updated_at TIMESTAMP", - "transaction_hash BYTES", - "ownerable_type STRING", - "ownerable_id STRING", - "user_nonce INT64", - "transfer_type INT64", - "blockchain INT64", - "deal_type STRING", - "deal_id STRING", - "ethereum_transaction_id STRING", - "ignore_price BOOL", - "card_eth_value FLOAT64", - "paid_eth_price FLOAT64", - "card_bought_notified BOOL", - "address NUMERIC", - "account_id STRING", - "asset_id NUMERIC", - "status INT64", - "transaction_id STRING", - "settled_at TIMESTAMP", - "reference_id STRING", - "settle_at TIMESTAMP", - "settlement_delay_reason INT64", +func getOwnersSchema() *model.QRecordSchema { + return &model.QRecordSchema{ + Fields: []*model.QField{ + {Name: "id", Type: model.QValueKindString, Nullable: true}, + {Name: "card_id", Type: model.QValueKindString, Nullable: true}, + {Name: "from_v", Type: model.QValueKindETime, Nullable: true}, + {Name: "price", Type: model.QValueKindNumeric, Nullable: true}, + {Name: "created_at", Type: model.QValueKindETime, Nullable: true}, + {Name: "updated_at", Type: model.QValueKindETime, Nullable: true}, + {Name: "transaction_hash", Type: model.QValueKindBytes, Nullable: true}, + {Name: "ownerable_type", Type: model.QValueKindString, Nullable: true}, + {Name: "ownerable_id", Type: model.QValueKindString, Nullable: true}, + {Name: "user_nonce", Type: model.QValueKindInt64, Nullable: true}, + {Name: "transfer_type", Type: model.QValueKindInt64, Nullable: true}, + {Name: "blockchain", Type: model.QValueKindInt64, Nullable: true}, + {Name: "deal_type", Type: model.QValueKindString, Nullable: true}, + {Name: "deal_id", Type: model.QValueKindString, Nullable: true}, + {Name: "ethereum_transaction_id", Type: model.QValueKindString, Nullable: true}, + {Name: "ignore_price", Type: model.QValueKindBoolean, Nullable: true}, + {Name: "card_eth_value", Type: model.QValueKindFloat64, Nullable: true}, + {Name: "paid_eth_price", Type: model.QValueKindFloat64, Nullable: true}, + {Name: "card_bought_notified", Type: model.QValueKindBoolean, Nullable: true}, + {Name: "address", Type: model.QValueKindNumeric, Nullable: true}, + {Name: "account_id", Type: model.QValueKindString, Nullable: true}, + {Name: "asset_id", Type: model.QValueKindNumeric, Nullable: true}, + {Name: "status", Type: model.QValueKindInt64, Nullable: true}, + {Name: "transaction_id", Type: model.QValueKindString, Nullable: true}, + {Name: "settled_at", Type: model.QValueKindETime, Nullable: true}, + {Name: "reference_id", Type: model.QValueKindString, Nullable: true}, + {Name: "settle_at", Type: model.QValueKindETime, Nullable: true}, + {Name: "settlement_delay_reason", Type: model.QValueKindInt64, Nullable: true}, + }, } +} - dstTableCmd := fmt.Sprintf( - "CREATE TABLE %s (%s)", - dstTableName, - strings.Join(colWithTypes, ","), - ) - err := s.bqHelper.RunCommand(dstTableCmd) +func getOwnersSelectorString() string { + schema := getOwnersSchema() + var fields []string + for _, field := range schema.Fields { + fields = append(fields, field.Name) + } + return strings.Join(fields, ",") +} + +func (s *E2EPeerFlowTestSuite) setupBQDestinationTable(dstTable string) { + schema := getOwnersSchema() + err := s.bqHelper.CreateTable(dstTable, schema) // fail if table creation fails s.NoError(err) - fmt.Printf("created table on bigquery: %s. %v\n", dstTableName, err) + fmt.Printf("created table on bigquery: %s.%s. %v\n", s.bqHelper.Config.DatasetId, dstTable, err) +} + +func (s *E2EPeerFlowTestSuite) setupSFDestinationTable(dstTable string) { + schema := getOwnersSchema() + err := s.sfHelper.CreateTable(dstTable, schema) + + // fail if table creation fails + if err != nil { + s.Fail("unable to create table on snowflake", err) + } + + fmt.Printf("created table on snowflake: %s.%s. %v\n", s.sfHelper.testSchemaName, dstTable, err) } -func (s *E2EPeerFlowTestSuite) createWorkflowConfig( +func (s *E2EPeerFlowTestSuite) createQRepWorkflowConfig( flowJobName string, sourceTable string, dstTable string, query string, syncMode protos.QRepSyncMode, + dest *protos.Peer, ) *protos.QRepConfig { connectionGen := QRepFlowConnectionGenerationConfig{ FlowJobName: flowJobName, SourceTableIdentifier: sourceTable, DestinationTableIdentifier: dstTable, PostgresPort: postgresPort, - BigQueryConfig: s.bqHelper.Config, + Destination: dest, } watermark := "updated_at" @@ -150,7 +172,7 @@ func (s *E2EPeerFlowTestSuite) createWorkflowConfig( return qrepConfig } -func (s *E2EPeerFlowTestSuite) compareTableContents(tableName string) { +func (s *E2EPeerFlowTestSuite) compareTableContentsBQ(tableName string) { // read rows from source table pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background()) pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery( @@ -168,6 +190,24 @@ func (s *E2EPeerFlowTestSuite) compareTableContents(tableName string) { s.True(pgRows.Equals(bqRows), "rows from source and destination tables are not equal") } +func (s *E2EPeerFlowTestSuite) compareTableContentsSF(tableName string, selector string) { + // read rows from source table + pgQueryExecutor := connpostgres.NewQRepQueryExecutor(s.pool, context.Background()) + pgRows, err := pgQueryExecutor.ExecuteAndProcessQuery( + fmt.Sprintf("SELECT %s FROM e2e_test.%s ORDER BY id", selector, tableName), + ) + s.NoError(err) + + // read rows from destination table + qualifiedTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, tableName) + sfRows, err := s.sfHelper.ExecuteAndProcessQuery( + fmt.Sprintf("SELECT %s FROM %s ORDER BY id", selector, qualifiedTableName), + ) + s.NoError(err) + + s.True(pgRows.Equals(sfRows), "rows from source and destination tables are not equal") +} + // Test_Complete_QRep_Flow tests a complete flow with data in the source table. // The test inserts 10 rows into the source table and verifies that the data is // correctly synced to the destination table this runs a QRep Flow. @@ -179,13 +219,14 @@ func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Multi_Insert() { tblName := "test_qrep_flow_multi_insert" s.setupSourceTable(tblName, numRows) - s.setupDestinationTable(tblName) + s.setupBQDestinationTable(tblName) - qrepConfig := s.createWorkflowConfig("test_qrep_flow_mi", + qrepConfig := s.createQRepWorkflowConfig("test_qrep_flow_mi", "e2e_test."+tblName, tblName, "SELECT * FROM e2e_test."+tblName, - protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT) + protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT, + s.bqHelper.Peer) env.ExecuteWorkflow(peerflow.QRepFlowWorkflow, qrepConfig) // Verify workflow completes without error @@ -211,13 +252,49 @@ func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Avro() { tblName := "test_qrep_flow_avro" s.setupSourceTable(tblName, numRows) - s.setupDestinationTable(tblName) + s.setupBQDestinationTable(tblName) - qrepConfig := s.createWorkflowConfig("test_qrep_flow_avro", + qrepConfig := s.createQRepWorkflowConfig("test_qrep_flow_avro", "e2e_test."+tblName, tblName, "SELECT * FROM e2e_test."+tblName, - protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO) + protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO, + s.bqHelper.Peer) + env.ExecuteWorkflow(peerflow.QRepFlowWorkflow, qrepConfig) + + // Verify workflow completes without error + s.True(env.IsWorkflowCompleted()) + + // assert that error contains "invalid connection configs" + err := env.GetWorkflowError() + s.NoError(err) + + s.compareTableContentsBQ(tblName) + + env.AssertExpectations(s.T()) +} + +func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Avro_SF() { + env := s.NewTestWorkflowEnvironment() + registerWorkflowsAndActivities(env) + + numRows := 1 + + tblName := "test_qrep_flow_avro_sf" + s.setupSourceTable(tblName, numRows) + s.setupSFDestinationTable(tblName) + + dstSchemaQualified := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, tblName) + + qrepConfig := s.createQRepWorkflowConfig( + "test_qrep_flow_avro_Sf", + "e2e_test."+tblName, + dstSchemaQualified, + "SELECT * FROM e2e_test."+tblName, + protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO, + s.sfHelper.Peer, + ) + env.ExecuteWorkflow(peerflow.QRepFlowWorkflow, qrepConfig) // Verify workflow completes without error @@ -227,7 +304,8 @@ func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Avro() { err := env.GetWorkflowError() s.NoError(err) - s.compareTableContents(tblName) + sel := getOwnersSelectorString() + s.compareTableContentsSF(tblName, sel) env.AssertExpectations(s.T()) } diff --git a/flow/e2e/snowflake_helper.go b/flow/e2e/snowflake_helper.go new file mode 100644 index 0000000000..ebea957001 --- /dev/null +++ b/flow/e2e/snowflake_helper.go @@ -0,0 +1,95 @@ +package e2e + +import ( + "context" + "encoding/json" + "fmt" + "os" + + connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model" +) + +type SnowflakeTestHelper struct { + // config is the Snowflake config. + Config *protos.SnowflakeConfig + // peer struct holder Snowflake + Peer *protos.Peer + // connection to Snowflake + client *connsnowflake.SnowflakeClient + // testSchemaName is the schema to use for testing. + testSchemaName string +} + +func NewSnowflakeTestHelper(testSchemaName string) (*SnowflakeTestHelper, error) { + jsonPath := os.Getenv("TEST_SF_CREDS") + if jsonPath == "" { + return nil, fmt.Errorf("TEST_SF_CREDS env var not set") + } + + content, err := readFileToBytes(jsonPath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + var config protos.SnowflakeConfig + err = json.Unmarshal(content, &config) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal json: %w", err) + } + + peer := generateSFPeer(&config) + + client, err := connsnowflake.NewSnowflakeClient(context.Background(), &config) + if err != nil { + return nil, fmt.Errorf("failed to create Snowflake client: %w", err) + } + + return &SnowflakeTestHelper{ + Config: &config, + Peer: peer, + client: client, + testSchemaName: testSchemaName, + }, nil +} + +func generateSFPeer(snowflakeConfig *protos.SnowflakeConfig) *protos.Peer { + ret := &protos.Peer{} + ret.Name = "test_sf_peer" + ret.Type = protos.DBType_SNOWFLAKE + + ret.Config = &protos.Peer_SnowflakeConfig{ + SnowflakeConfig: snowflakeConfig, + } + + return ret +} + +// RecreateSchema recreates the schema, i.e., drops it if exists and creates it again. +func (s *SnowflakeTestHelper) RecreateSchema() error { + return s.client.RecreateSchema(s.testSchemaName) +} + +// DropSchema drops the schema. +func (s *SnowflakeTestHelper) DropSchema() error { + return s.client.DropSchema(s.testSchemaName) +} + +// RunCommand runs the given command. +func (s *SnowflakeTestHelper) RunCommand(command string) error { + return s.client.RunCommand(command) +} + +// CountRows(tableName) returns the number of rows in the given table. +func (s *SnowflakeTestHelper) CountRows(tableName string) (int, error) { + return s.client.CountRows(s.testSchemaName, tableName) +} + +func (s *SnowflakeTestHelper) ExecuteAndProcessQuery(query string) (*model.QRecordBatch, error) { + return s.client.ExecuteAndProcessQuery(query) +} + +func (s *SnowflakeTestHelper) CreateTable(tableName string, schema *model.QRecordSchema) error { + return s.client.CreateTable(schema, s.testSchemaName, tableName) +} diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go new file mode 100644 index 0000000000..293af9095f --- /dev/null +++ b/flow/e2e/test_utils.go @@ -0,0 +1,26 @@ +package e2e + +import ( + "fmt" + "io" + "os" +) + +// readFileToBytes reads a file to a byte array. +func readFileToBytes(path string) ([]byte, error) { + var ret []byte + + f, err := os.Open(path) + if err != nil { + return ret, fmt.Errorf("failed to open file: %w", err) + } + + defer f.Close() + + ret, err = io.ReadAll(f) + if err != nil { + return ret, fmt.Errorf("failed to read file: %w", err) + } + + return ret, nil +} diff --git a/flow/go.mod b/flow/go.mod index 2ef1142d56..266981f154 100644 --- a/flow/go.mod +++ b/flow/go.mod @@ -10,6 +10,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/jackc/pglogrepl v0.0.0-20230428004623-0c5b98f52784 github.com/jackc/pgx/v5 v5.3.1 + github.com/jmoiron/sqlx v1.3.5 github.com/linkedin/goavro v2.1.0+incompatible github.com/linkedin/goavro/v2 v2.12.0 github.com/sirupsen/logrus v1.9.3 diff --git a/flow/go.sum b/flow/go.sum index 1da1a6abc0..8bea0bea5a 100644 --- a/flow/go.sum +++ b/flow/go.sum @@ -773,6 +773,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.14.1 h1:9c50NUPC30zyuKprjL3vNZ0m5oG+jU0zvx4AqHGnv4k= github.com/go-playground/validator/v10 v10.14.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= @@ -939,6 +941,8 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -970,6 +974,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/linkedin/goavro v2.1.0+incompatible h1:DV2aUlj2xZiuxQyvag8Dy7zjY69ENjS66bWkSfdpddY= github.com/linkedin/goavro v2.1.0+incompatible/go.mod h1:bBCwI2eGYpUI/4820s67MElg9tdeLbINjLjiM2xZFYM= github.com/linkedin/goavro/v2 v2.12.0 h1:rIQQSj8jdAUlKQh6DttK8wCRv4t4QO09g1C4aBWXslg= @@ -981,6 +987,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw= github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= diff --git a/flow/model/qdb_type.go b/flow/model/qdb_type.go new file mode 100644 index 0000000000..b8ad665900 --- /dev/null +++ b/flow/model/qdb_type.go @@ -0,0 +1,10 @@ +package model + +type QDBType int + +const ( + QDBTypeUnknown QDBType = 0 + QDBTypePostgres QDBType = 1 + QDBTypeSnowflake QDBType = 2 + QDBTypeBigQuery QDBType = 3 +) diff --git a/flow/model/qrecord.go b/flow/model/qrecord.go index bf244b74bf..d7b67d949a 100644 --- a/flow/model/qrecord.go +++ b/flow/model/qrecord.go @@ -32,6 +32,7 @@ func (q *QRecord) equals(other *QRecord) bool { for i, entry := range q.Entries { otherEntry := other.Entries[i] if !entry.Equals(&otherEntry) { + fmt.Printf("entry %d: %v != %v\n", i, entry, otherEntry) return false } } @@ -40,6 +41,7 @@ func (q *QRecord) equals(other *QRecord) bool { } func (q *QRecord) ToAvroCompatibleMap( + targetDB QDBType, nullableFields *map[string]bool, colNames []string, ) (map[string]interface{}, error) { @@ -48,7 +50,7 @@ func (q *QRecord) ToAvroCompatibleMap( for idx, qValue := range q.Entries { key := colNames[idx] nullable, ok := (*nullableFields)[key] - avroVal, err := qValue.ToAvroValue(nullable && ok) + avroVal, err := qValue.ToAvroValue(targetDB, nullable && ok) if err != nil { return nil, fmt.Errorf("failed to convert QValue to Avro-compatible value: %v", err) } diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index 7777fb6e45..88b2d8168f 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -2,9 +2,9 @@ package model // QRecordBatch holds a batch of QRecord objects. type QRecordBatch struct { - NumRecords uint32 // NumRecords represents the number of records in the batch. - Records []*QRecord // Records is a slice of pointers to QRecord objects. - ColumnNames []string // ColumnNames is a slice of column names. + NumRecords uint32 // NumRecords represents the number of records in the batch. + Records []*QRecord // Records is a slice of pointers to QRecord objects. + Schema *QRecordSchema } // Equals checks if two QRecordBatches are identical. @@ -14,15 +14,13 @@ func (q *QRecordBatch) Equals(other *QRecordBatch) bool { } // First check simple attributes - if q.NumRecords != other.NumRecords || len(q.ColumnNames) != len(other.ColumnNames) { + if q.NumRecords != other.NumRecords { return false } // Compare column names - for i, colName := range q.ColumnNames { - if colName != other.ColumnNames[i] { - return false - } + if !q.Schema.EqualNames(other.Schema) { + return false } // Compare records diff --git a/flow/model/qschema.go b/flow/model/qschema.go new file mode 100644 index 0000000000..eb6811ed15 --- /dev/null +++ b/flow/model/qschema.go @@ -0,0 +1,51 @@ +package model + +import "strings" + +type QField struct { + Name string + Type QValueKind + Nullable bool +} + +type QRecordSchema struct { + Fields []*QField +} + +// NewQRecordSchema creates a new QRecordSchema. +func NewQRecordSchema(fields []*QField) *QRecordSchema { + return &QRecordSchema{ + Fields: fields, + } +} + +// EqualNames returns true if the field names are equal. +func (q *QRecordSchema) EqualNames(other *QRecordSchema) bool { + if other == nil { + return q == nil + } + + if len(q.Fields) != len(other.Fields) { + return false + } + + for i, field := range q.Fields { + // ignore the case of the field name convert to lower case + f1 := strings.ToLower(field.Name) + f2 := strings.ToLower(other.Fields[i].Name) + if f1 != f2 { + return false + } + } + + return true +} + +// GetColumnNames returns a slice of column names. +func (q *QRecordSchema) GetColumnNames() []string { + var names []string + for _, field := range q.Fields { + names = append(names, field.Name) + } + return names +} diff --git a/flow/model/qschema_avro.go b/flow/model/qschema_avro.go new file mode 100644 index 0000000000..63dbc86f35 --- /dev/null +++ b/flow/model/qschema_avro.go @@ -0,0 +1,141 @@ +package model + +import ( + "encoding/json" + "fmt" +) + +type QRecordAvroField struct { + Name string `json:"name"` + Type interface{} `json:"type"` +} + +type QRecordAvroSchema struct { + Type string `json:"type"` + Name string `json:"name"` + Fields []QRecordAvroField `json:"fields"` +} + +type QRecordAvroSchemaDefinition struct { + Schema string + NullableFields map[string]bool +} + +func GetAvroSchemaDefinition(dstTableName string, qRecordSchema *QRecordSchema) (*QRecordAvroSchemaDefinition, error) { + avroFields := []QRecordAvroField{} + nullableFields := map[string]bool{} + + for _, qField := range qRecordSchema.Fields { + avroType, err := GetAvroType(qField) + if err != nil { + return nil, err + } + + consolidatedType := avroType.AType + + if avroType.RespectNull && qField.Nullable { + consolidatedType = []interface{}{"null", consolidatedType} + nullableFields[qField.Name] = true + } + + avroFields = append(avroFields, QRecordAvroField{ + Name: qField.Name, + Type: consolidatedType, + }) + } + + avroSchema := QRecordAvroSchema{ + Type: "record", + Name: dstTableName, + Fields: avroFields, + } + + avroSchemaJSON, err := json.Marshal(avroSchema) + if err != nil { + return nil, fmt.Errorf("failed to marshal Avro schema to JSON: %v", err) + } + + return &QRecordAvroSchemaDefinition{ + Schema: string(avroSchemaJSON), + NullableFields: nullableFields, + }, nil +} + +type avroType struct { + AType interface{} + RespectNull bool +} + +func GetAvroType(qField *QField) (*avroType, error) { + switch qField.Type { + case QValueKindString: + return &avroType{ + AType: "string", + RespectNull: qField.Nullable, + }, nil + case QValueKindInt16, QValueKindInt32: + return &avroType{ + AType: "long", + RespectNull: qField.Nullable, + }, nil + case QValueKindInt64: + return &avroType{ + AType: "long", + RespectNull: qField.Nullable, + }, nil + case QValueKindFloat16, QValueKindFloat32: + return &avroType{ + AType: "float", + RespectNull: qField.Nullable, + }, nil + case QValueKindFloat64: + return &avroType{ + AType: "double", + RespectNull: qField.Nullable, + }, nil + case QValueKindBoolean: + return &avroType{ + AType: "boolean", + RespectNull: qField.Nullable, + }, nil + case QValueKindBytes: + return &avroType{ + AType: "bytes", + RespectNull: qField.Nullable, + }, nil + case QValueKindNumeric: + // For the case of numeric values, you may need to be more specific depending + // on the range and precision of your numeric data. + return &avroType{ + AType: map[string]interface{}{ + "type": "bytes", + "logicalType": "decimal", + "precision": 38, + "scale": 9, + }, + RespectNull: false, + }, nil + case QValueKindUUID: + // treat UUID as a string + return &avroType{ + AType: "string", + RespectNull: qField.Nullable, + }, nil + case QValueKindETime: + return &avroType{ + AType: map[string]string{ + "type": "long", + "logicalType": "timestamp-micros", + }, + RespectNull: false, + }, nil + case QValueKindJSON, QValueKindArray, QValueKindStruct: + // Handling complex types like JSON, Array, and Struct may require more complex logic + return nil, fmt.Errorf("complex types not supported yet: %s", qField.Type) + case QValueKindBit: + // Bit types may need their own specific handling or may not map directly to Avro types + return nil, fmt.Errorf("unsupported QField type: %s", qField.Type) + default: + return nil, fmt.Errorf("unsupported QField type: %s", qField.Type) + } +} diff --git a/flow/model/qvalue.go b/flow/model/qvalue.go index ab2f8b13ed..2e744e70d4 100644 --- a/flow/model/qvalue.go +++ b/flow/model/qvalue.go @@ -7,45 +7,22 @@ import ( "math/big" "reflect" "strconv" - "time" "github.com/google/uuid" "github.com/linkedin/goavro" ) -type QValueKind string - -const ( - QValueKindInvalid QValueKind = "invalid" - QValueKindFloat16 QValueKind = "float16" - QValueKindFloat32 QValueKind = "float32" - QValueKindFloat64 QValueKind = "float64" - QValueKindInt16 QValueKind = "int16" - QValueKindInt32 QValueKind = "int32" - QValueKindInt64 QValueKind = "int64" - QValueKindBoolean QValueKind = "bool" - QValueKindArray QValueKind = "array" - QValueKindStruct QValueKind = "struct" - QValueKindString QValueKind = "string" - QValueKindETime QValueKind = "extended_time" - QValueKindNumeric QValueKind = "numeric" - QValueKindBytes QValueKind = "bytes" - QValueKindUUID QValueKind = "uuid" - QValueKindJSON QValueKind = "json" - QValueKindBit QValueKind = "bit" -) - type QValue struct { Kind QValueKind Value interface{} } -func (q *QValue) ToAvroValue(isNullable bool) (interface{}, error) { +func (q *QValue) ToAvroValue(targetDB QDBType, isNullable bool) (interface{}, error) { switch q.Kind { case QValueKindInvalid: return nil, fmt.Errorf("invalid QValueKind") case QValueKindETime: - return processExtendedTime(q) + return processExtendedTime(targetDB, q) case QValueKindString: return processNullableUnion(isNullable, "string", q.Value) case QValueKindFloat16, QValueKindFloat32, QValueKindFloat64: @@ -75,7 +52,7 @@ func (q *QValue) ToAvroValue(isNullable bool) (interface{}, error) { } } -func processExtendedTime(q *QValue) (interface{}, error) { +func processExtendedTime(targetDB QDBType, q *QValue) (interface{}, error) { et, ok := q.Value.(*ExtendedTime) if !ok { return nil, fmt.Errorf("invalid ExtendedTime value") @@ -83,11 +60,19 @@ func processExtendedTime(q *QValue) (interface{}, error) { switch et.NestedKind.Type { case DateTimeKindType: - return et.Time.UnixNano() / int64(time.Millisecond), nil + ret := et.Time.UnixMicro() + // Snowflake has issues with avro timestamp types + // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file + if targetDB == QDBTypeSnowflake { + ret = ret / 1000000 + } + return ret, nil case DateKindType: - return et.Time.Format("2006-01-02"), nil + ret := et.Time.Format("2006-01-02") + return ret, nil case TimeKindType: - return et.Time.Format("15:04:05.999999"), nil + ret := et.Time.Format("15:04:05.999999") + return ret, nil default: return nil, fmt.Errorf("unsupported ExtendedTimeKindType: %s", et.NestedKind.Type) } @@ -225,8 +210,10 @@ func compareETime(value1, value2 interface{}) bool { return false } - t1 := et1.Time.UnixNano() / int64(time.Millisecond) - t2 := et2.Time.UnixNano() / int64(time.Millisecond) + // TODO: this is a hack, we should be comparing the actual time values + // currently this is only used for testing so that is OK. + t1 := et1.Time.UnixMilli() / 1000 + t2 := et2.Time.UnixMilli() / 1000 return t1 == t2 } @@ -358,6 +345,8 @@ func getInt32(v interface{}) (int32, bool) { return value, true case int64: return int32(value), true + case *big.Rat: + return int32(value.Num().Int64()), true case string: parsed, err := strconv.ParseInt(value, 10, 32) if err == nil { @@ -373,6 +362,8 @@ func getInt64(v interface{}) (int64, bool) { return value, true case int32: return int64(value), true + case *big.Rat: + return value.Num().Int64(), true case string: parsed, err := strconv.ParseInt(value, 10, 64) if err == nil { diff --git a/flow/model/qvalue_kind.go b/flow/model/qvalue_kind.go new file mode 100644 index 0000000000..e95a53267f --- /dev/null +++ b/flow/model/qvalue_kind.go @@ -0,0 +1,23 @@ +package model + +type QValueKind string + +const ( + QValueKindInvalid QValueKind = "invalid" + QValueKindFloat16 QValueKind = "float16" + QValueKindFloat32 QValueKind = "float32" + QValueKindFloat64 QValueKind = "float64" + QValueKindInt16 QValueKind = "int16" + QValueKindInt32 QValueKind = "int32" + QValueKindInt64 QValueKind = "int64" + QValueKindBoolean QValueKind = "bool" + QValueKindArray QValueKind = "array" + QValueKindStruct QValueKind = "struct" + QValueKindString QValueKind = "string" + QValueKindETime QValueKind = "extended_time" + QValueKindNumeric QValueKind = "numeric" + QValueKindBytes QValueKind = "bytes" + QValueKindUUID QValueKind = "uuid" + QValueKindJSON QValueKind = "json" + QValueKindBit QValueKind = "bit" +) diff --git a/flow/utils/crypto.go b/flow/utils/crypto.go new file mode 100644 index 0000000000..8a84125828 --- /dev/null +++ b/flow/utils/crypto.go @@ -0,0 +1,25 @@ +package util + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" +) + +func DecodePKCS8PrivateKey(rawKey []byte) (*rsa.PrivateKey, error) { + PEMBlock, _ := pem.Decode(rawKey) + if PEMBlock == nil { + return nil, fmt.Errorf("failed to decode private key PEM block") + } + privateKeyAny, err := x509.ParsePKCS8PrivateKey(PEMBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse private key PEM block as PKCS8: %w", err) + } + privateKeyRSA, ok := privateKeyAny.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("key does not appear to RSA as expected") + } + + return privateKeyRSA, nil +} diff --git a/flow/utils/random.go b/flow/utils/random.go new file mode 100644 index 0000000000..a3c22330ca --- /dev/null +++ b/flow/utils/random.go @@ -0,0 +1,29 @@ +package util + +import ( + "crypto/rand" + "encoding/binary" + "errors" +) + +// RandomInt64 returns a random 64 bit integer. +func RandomInt64() (int64, error) { + b := make([]byte, 8) + _, err := rand.Read(b) + if err != nil { + return 0, errors.New("could not generate random int64: " + err.Error()) + } + // Convert bytes to int64 + return int64(binary.LittleEndian.Uint64(b)), nil +} + +// RandomUInt64 returns a random 64 bit unsigned integer. +func RandomUInt64() (uint64, error) { + b := make([]byte, 8) + _, err := rand.Read(b) + if err != nil { + return 0, errors.New("could not generate random uint64: " + err.Error()) + } + // Convert bytes to uint64 + return binary.LittleEndian.Uint64(b), nil +} diff --git a/flow/workflows/qrep_flow.go b/flow/workflows/qrep_flow.go index 391e8aed10..e8ad22f602 100644 --- a/flow/workflows/qrep_flow.go +++ b/flow/workflows/qrep_flow.go @@ -70,6 +70,9 @@ func (q *QRepFlowExecution) ReplicatePartition(ctx workflow.Context, partition * ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ StartToCloseTimeout: 15 * time.Minute, + RetryPolicy: &temporal.RetryPolicy{ + MaximumAttempts: 2, + }, }) if err := workflow.ExecuteActivity(ctx,