From 93d754a405cd672baa2ca3ea3c9388cb575fbb15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Mon, 25 Dec 2023 20:45:46 +0000 Subject: [PATCH 1/2] Fix golangci-lint failures (#902) --- flow/.golangci.yml | 8 ++++++-- flow/activities/slot.go | 2 +- flow/cmd/api.go | 5 ++++- flow/cmd/handler.go | 4 +++- flow/cmd/peer_data.go | 16 ++++++++-------- flow/cmd/snapshot_worker.go | 5 ++++- flow/cmd/worker.go | 5 ++++- flow/connectors/postgres/cdc.go | 4 ++-- flow/connectors/postgres/client.go | 6 +++--- .../snowflake/avro_file_writer_test.go | 1 + flow/connectors/utils/ssh.go | 5 +++-- flow/e2e/snowflake/qrep_flow_sf_test.go | 1 + 12 files changed, 40 insertions(+), 22 deletions(-) diff --git a/flow/.golangci.yml b/flow/.golangci.yml index 2c8032f91d..ab612fb893 100644 --- a/flow/.golangci.yml +++ b/flow/.golangci.yml @@ -4,7 +4,6 @@ run: linters: enable: - dogsled - - dupl - gofumpt - gosec - gosimple @@ -18,9 +17,14 @@ linters: - prealloc - staticcheck - ineffassign + - unparam - unused - lll linters-settings: + stylecheck: + checks: + - all + - '-ST1003' lll: - line-length: 120 + line-length: 144 tab-width: 4 diff --git a/flow/activities/slot.go b/flow/activities/slot.go index 117dbecea3..baa0fbc0fa 100644 --- a/flow/activities/slot.go +++ b/flow/activities/slot.go @@ -23,7 +23,7 @@ func (a *FlowableActivity) handleSlotInfo( return err } - if slotInfo == nil || len(slotInfo) == 0 { + if len(slotInfo) == 0 { slog.WarnContext(ctx, "warning: unable to get slot info", slog.Any("slotName", slotName)) return nil } diff --git a/flow/cmd/api.go b/flow/cmd/api.go index 09185a0fc1..b16034bf20 100644 --- a/flow/cmd/api.go +++ b/flow/cmd/api.go @@ -104,7 +104,10 @@ func APIMain(args *APIServerParams) error { } connOptions := client.ConnectionOptions{ - TLS: &tls.Config{Certificates: certs}, + TLS: &tls.Config{ + Certificates: certs, + MinVersion: tls.VersionTLS13, + }, } clientOptions.ConnectionOptions = connOptions } diff --git a/flow/cmd/handler.go b/flow/cmd/handler.go index 7b03c9de67..dd922ba1f5 100644 --- a/flow/cmd/handler.go +++ b/flow/cmd/handler.go @@ -261,7 +261,9 @@ func (h *FlowRequestHandler) CreateQRepFlow( slog.Any("error", err), slog.String("flowName", cfg.FlowJobName)) return nil, fmt.Errorf("invalid xmin txid for xmin rep: %w", err) } - state.LastPartition.Range = &protos.PartitionRange{Range: &protos.PartitionRange_IntRange{IntRange: &protos.IntPartitionRange{Start: txid}}} + state.LastPartition.Range = &protos.PartitionRange{ + Range: &protos.PartitionRange_IntRange{IntRange: &protos.IntPartitionRange{Start: txid}}, + } } workflowFn = peerflow.XminFlowWorkflow diff --git a/flow/cmd/peer_data.go b/flow/cmd/peer_data.go index 34f31219ed..f9383d8c5e 100644 --- a/flow/cmd/peer_data.go +++ b/flow/cmd/peer_data.go @@ -31,24 +31,24 @@ func (h *FlowRequestHandler) getPGPeerConfig(ctx context.Context, peerName strin return &pgPeerConfig, nil } -func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*pgxpool.Pool, string, error) { +func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*pgxpool.Pool, error) { pgPeerConfig, err := h.getPGPeerConfig(ctx, peerName) if err != nil { - return nil, "", err + return nil, err } connStr := utils.GetPGConnectionString(pgPeerConfig) peerPool, err := pgxpool.New(ctx, connStr) if err != nil { - return nil, "", err + return nil, err } - return peerPool, pgPeerConfig.User, nil + return peerPool, nil } func (h *FlowRequestHandler) GetSchemas( ctx context.Context, req 
*protos.PostgresPeerActivityInfoRequest, ) (*protos.PeerSchemasResponse, error) { - peerPool, _, err := h.getPoolForPGPeer(ctx, req.PeerName) + peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName) if err != nil { return &protos.PeerSchemasResponse{Schemas: nil}, err } @@ -78,7 +78,7 @@ func (h *FlowRequestHandler) GetTablesInSchema( ctx context.Context, req *protos.SchemaTablesRequest, ) (*protos.SchemaTablesResponse, error) { - peerPool, _, err := h.getPoolForPGPeer(ctx, req.PeerName) + peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName) if err != nil { return &protos.SchemaTablesResponse{Tables: nil}, err } @@ -110,7 +110,7 @@ func (h *FlowRequestHandler) GetAllTables( ctx context.Context, req *protos.PostgresPeerActivityInfoRequest, ) (*protos.AllTablesResponse, error) { - peerPool, _, err := h.getPoolForPGPeer(ctx, req.PeerName) + peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName) if err != nil { return &protos.AllTablesResponse{Tables: nil}, err } @@ -140,7 +140,7 @@ func (h *FlowRequestHandler) GetColumns( ctx context.Context, req *protos.TableColumnsRequest, ) (*protos.TableColumnsResponse, error) { - peerPool, _, err := h.getPoolForPGPeer(ctx, req.PeerName) + peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName) if err != nil { return &protos.TableColumnsResponse{Columns: nil}, err } diff --git a/flow/cmd/snapshot_worker.go b/flow/cmd/snapshot_worker.go index 16008cc6a5..c68d44d925 100644 --- a/flow/cmd/snapshot_worker.go +++ b/flow/cmd/snapshot_worker.go @@ -32,7 +32,10 @@ func SnapshotWorkerMain(opts *SnapshotWorkerOptions) error { } connOptions := client.ConnectionOptions{ - TLS: &tls.Config{Certificates: certs}, + TLS: &tls.Config{ + Certificates: certs, + MinVersion: tls.VersionTLS13, + }, } clientOptions.ConnectionOptions = connOptions } diff --git a/flow/cmd/worker.go b/flow/cmd/worker.go index eea0e9184f..f060230b63 100644 --- a/flow/cmd/worker.go +++ b/flow/cmd/worker.go @@ -100,7 +100,10 @@ func WorkerMain(opts *WorkerOptions) error { return fmt.Errorf("unable to process certificate and key: %w", err) } connOptions := client.ConnectionOptions{ - TLS: &tls.Config{Certificates: certs}, + TLS: &tls.Config{ + Certificates: certs, + MinVersion: tls.VersionTLS13, + }, } clientOptions.ConnectionOptions = connOptions } diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index b3686f4d09..2be3fcb2a5 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -360,7 +360,6 @@ func (p *PostgresCDCSource) consumeStream( p.logger.Debug(fmt.Sprintf("XLogData => WALStart %s ServerWALEnd %s ServerTime %s\n", xld.WALStart, xld.ServerWALEnd, xld.ServerTime)) rec, err := p.processMessage(records, xld, clientXLogPos) - if err != nil { return fmt.Errorf("error processing message: %w", err) } @@ -470,7 +469,8 @@ func (p *PostgresCDCSource) consumeStream( } func (p *PostgresCDCSource) processMessage(batch *model.CDCRecordStream, xld pglogrepl.XLogData, - currentClientXlogPos pglogrepl.LSN) (model.Record, error) { + currentClientXlogPos pglogrepl.LSN, +) (model.Record, error) { logicalMsg, err := pglogrepl.Parse(xld.WALData) if err != nil { return nil, fmt.Errorf("error parsing logical message: %w", err) diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index dc604d5631..e48c71b29d 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -84,9 +84,9 @@ type ReplicaIdentityType rune const ( ReplicaIdentityDefault ReplicaIdentityType = 'd' - ReplicaIdentityFull = 'f' 
- ReplicaIdentityIndex = 'i'
- ReplicaIdentityNothing = 'n'
+ ReplicaIdentityFull ReplicaIdentityType = 'f'
+ ReplicaIdentityIndex ReplicaIdentityType = 'i'
+ ReplicaIdentityNothing ReplicaIdentityType = 'n'
 )

 // getRelIDForTable returns the relation ID for a table.
diff --git a/flow/connectors/snowflake/avro_file_writer_test.go b/flow/connectors/snowflake/avro_file_writer_test.go
index 46f18aaa3f..f08b66a6c8 100644
--- a/flow/connectors/snowflake/avro_file_writer_test.go
+++ b/flow/connectors/snowflake/avro_file_writer_test.go
@@ -55,6 +55,7 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.
 }
 }

+// nolint:unparam
 func generateRecords(
 t *testing.T,
 nullable bool,
diff --git a/flow/connectors/utils/ssh.go b/flow/connectors/utils/ssh.go
index 7bd8ed141f..511eea672a 100644
--- a/flow/connectors/utils/ssh.go
+++ b/flow/connectors/utils/ssh.go
@@ -41,8 +41,9 @@ func GetSSHClientConfig(user, password, privateKeyString string) (*ssh.ClientCon
 }

 return &ssh.ClientConfig{
- User: user,
- Auth: authMethods,
+ User: user,
+ Auth: authMethods,
+ //nolint:gosec
 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
 }, nil
 }
diff --git a/flow/e2e/snowflake/qrep_flow_sf_test.go b/flow/e2e/snowflake/qrep_flow_sf_test.go
index 3ac7fee713..b3cd9b9c2a 100644
--- a/flow/e2e/snowflake/qrep_flow_sf_test.go
+++ b/flow/e2e/snowflake/qrep_flow_sf_test.go
@@ -12,6 +12,7 @@ import (
 "github.com/stretchr/testify/require"
 )

+// nolint:unparam
 func (s PeerFlowE2ETestSuiteSF) setupSourceTable(tableName string, numRows int) {
 err := e2e.CreateTableForQRep(s.pool, s.pgSuffix, tableName)
 require.NoError(s.t, err)

From eb63a7685f85f4162430b9b33bbeffd3ed0de5f5 Mon Sep 17 00:00:00 2001
From: Kevin Biju <52661649+heavycrystal@users.noreply.github.com>
Date: Tue, 26 Dec 2023 18:44:21 +0530
Subject: [PATCH 2/2] added capability for BQ CDC across datasets (#904)

1) Just like Snowflake and Postgres, BigQuery now takes tables in the form of `<dataset>.<table>`. If the dataset is omitted, it defaults to the dataset specified at the time of peer creation.
2) If the dataset doesn't exist at the time of mirror creation, it is created during `SetupNormalizedTables`, before the tables in the dataset.
3) A check has also been added so that two source tables cannot point to the same destination table specified in two different formats.
---
 flow/connectors/bigquery/bigquery.go | 175 ++++++++++--------
 .../bigquery/merge_statement_generator.go | 46 ++---
 flow/connectors/bigquery/qrep.go | 10 +-
 flow/connectors/bigquery/qrep_avro_sync.go | 53 +++---
 flow/connectors/eventhub/eventhub.go | 2 +-
 flow/connectors/postgres/cdc.go | 4 +-
 flow/e2e/bigquery/bigquery_helper.go | 23 ++-
 flow/e2e/bigquery/peer_flow_bq_test.go | 70 ++++++-
 flow/model/model.go | 12 +-
 9 files changed, 248 insertions(+), 147 deletions(-)

diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 0a220ef424..3da34f99d7 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -29,9 +29,7 @@ import (
 const (
 /*
 Different batch Ids in code/BigQuery
- 1. batchID - identifier in raw/staging tables on target to depict which batch a row was inserted.
- 2. stagingBatchID - the random batch id we generate before ingesting into staging table.
- helps filter rows in the current batch before inserting into raw table.
+ 1. batchID - identifier in raw table on target to depict which batch a row was inserted.
 3. syncBatchID - batch id that was last synced or will be synced
 4. 
normalizeBatchID - batch id that was last normalized or will be normalized. */ @@ -233,8 +231,8 @@ func (c *BigQueryConnector) InitializeTableSchema(req map[string]*protos.TableSc return nil } -func (c *BigQueryConnector) waitForTableReady(tblName string) error { - table := c.client.Dataset(c.datasetID).Table(tblName) +func (c *BigQueryConnector) waitForTableReady(datasetTable *datasetTable) error { + table := c.client.Dataset(datasetTable.dataset).Table(datasetTable.table) maxDuration := 5 * time.Minute deadline := time.Now().Add(maxDuration) sleepInterval := 5 * time.Second @@ -242,7 +240,7 @@ func (c *BigQueryConnector) waitForTableReady(tblName string) error { for { if time.Now().After(deadline) { - return fmt.Errorf("timeout reached while waiting for table %s to be ready", tblName) + return fmt.Errorf("timeout reached while waiting for table %s to be ready", datasetTable) } _, err := table.Metadata(c.ctx) @@ -250,7 +248,8 @@ func (c *BigQueryConnector) waitForTableReady(tblName string) error { return nil } - slog.Info("waiting for table to be ready", slog.String("table", tblName), slog.Int("attempt", attempt)) + slog.Info("waiting for table to be ready", + slog.String("table", datasetTable.table), slog.Int("attempt", attempt)) attempt++ time.Sleep(sleepInterval) } @@ -267,9 +266,10 @@ func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string, } for _, addedColumn := range schemaDelta.AddedColumns { + dstDatasetTable, _ := c.convertToDatasetTable(schemaDelta.DstTableName) _, err := c.client.Query(fmt.Sprintf( - "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s", c.datasetID, - schemaDelta.DstTableName, addedColumn.ColumnName, + "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s", dstDatasetTable.dataset, + dstDatasetTable.table, addedColumn.ColumnName, qValueKindToBigQueryType(addedColumn.ColumnType))).Read(c.ctx) if err != nil { return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName, @@ -593,16 +593,11 @@ func (c *BigQueryConnector) syncRecordsViaAvro( var entries [10]qvalue.QValue switch r := record.(type) { case *model.InsertRecord: - itemsJSON, err := r.Items.ToJSON() if err != nil { return nil, fmt.Errorf("failed to create items to json: %v", err) } - entries[3] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: r.DestinationTableName, - } entries[4] = qvalue.QValue{ Kind: qvalue.QValueKindString, Value: itemsJSON, @@ -626,16 +621,11 @@ func (c *BigQueryConnector) syncRecordsViaAvro( if err != nil { return nil, fmt.Errorf("failed to create new items to json: %v", err) } - oldItemsJSON, err := r.OldItems.ToJSON() if err != nil { return nil, fmt.Errorf("failed to create old items to json: %v", err) } - entries[3] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: r.DestinationTableName, - } entries[4] = qvalue.QValue{ Kind: qvalue.QValueKindString, Value: newItemsJSON, @@ -660,10 +650,6 @@ func (c *BigQueryConnector) syncRecordsViaAvro( return nil, fmt.Errorf("failed to create items to json: %v", err) } - entries[3] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: r.DestinationTableName, - } entries[4] = qvalue.QValue{ Kind: qvalue.QValueKindString, Value: itemsJSON, @@ -698,6 +684,10 @@ func (c *BigQueryConnector) syncRecordsViaAvro( Kind: qvalue.QValueKindInt64, Value: time.Now().UnixNano(), } + entries[3] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: record.GetDestinationTableName(), + } entries[7] = qvalue.QValue{ Kind: qvalue.QValueKindInt64, Value: syncBatchID, @@ -787,14 +777,18 @@ func (c 
*BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) c.datasetID, rawTableName, distinctTableNames)) for _, tableName := range distinctTableNames { + dstDatasetTable, _ := c.convertToDatasetTable(tableName) mergeGen := &mergeStmtGenerator{ - Dataset: c.datasetID, - NormalizedTable: tableName, - RawTable: rawTableName, - NormalizedTableSchema: c.tableNameSchemaMapping[tableName], - SyncBatchID: syncBatchID, - NormalizeBatchID: normalizeBatchID, - UnchangedToastColumns: tableNametoUnchangedToastCols[tableName], + rawDatasetTable: &datasetTable{ + dataset: c.datasetID, + table: rawTableName, + }, + dstTableName: tableName, + dstDatasetTable: dstDatasetTable, + normalizedTableSchema: c.tableNameSchemaMapping[tableName], + syncBatchID: syncBatchID, + normalizeBatchID: normalizeBatchID, + unchangedToastColumns: tableNametoUnchangedToastCols[tableName], peerdbCols: &protos.PeerDBColumns{ SoftDeleteColName: req.SoftDeleteColName, SyncedAtColName: req.SyncedAtColName, @@ -846,19 +840,6 @@ func (c *BigQueryConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr {Name: "_peerdb_unchanged_toast_columns", Type: bigquery.StringFieldType}, } - stagingSchema := bigquery.Schema{ - {Name: "_peerdb_uid", Type: bigquery.StringFieldType}, - {Name: "_peerdb_timestamp", Type: bigquery.TimestampFieldType}, - {Name: "_peerdb_timestamp_nanos", Type: bigquery.IntegerFieldType}, - {Name: "_peerdb_destination_table_name", Type: bigquery.StringFieldType}, - {Name: "_peerdb_data", Type: bigquery.StringFieldType}, - {Name: "_peerdb_record_type", Type: bigquery.IntegerFieldType}, - {Name: "_peerdb_match_data", Type: bigquery.StringFieldType}, - {Name: "_peerdb_batch_id", Type: bigquery.IntegerFieldType}, - {Name: "_peerdb_staging_batch_id", Type: bigquery.IntegerFieldType}, - {Name: "_peerdb_unchanged_toast_columns", Type: bigquery.StringFieldType}, - } - // create the table table := c.client.Dataset(c.datasetID).Table(rawTableName) @@ -883,16 +864,6 @@ func (c *BigQueryConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr return nil, fmt.Errorf("failed to create table %s.%s: %w", c.datasetID, rawTableName, err) } - // also create a staging table for this raw table - stagingTableName := c.getStagingTableName(req.FlowJobName) - stagingTable := c.client.Dataset(c.datasetID).Table(stagingTableName) - err = stagingTable.Create(c.ctx, &bigquery.TableMetadata{ - Schema: stagingSchema, - }) - if err != nil { - return nil, fmt.Errorf("failed to create table %s.%s: %w", c.datasetID, stagingTableName, err) - } - return &protos.CreateRawTableOutput{ TableIdentifier: rawTableName, }, nil @@ -952,14 +923,41 @@ func (c *BigQueryConnector) SetupNormalizedTables( req *protos.SetupNormalizedTableBatchInput, ) (*protos.SetupNormalizedTableBatchOutput, error) { tableExistsMapping := make(map[string]bool) + datasetTablesSet := make(map[datasetTable]struct{}) for tableIdentifier, tableSchema := range req.TableNameSchemaMapping { - table := c.client.Dataset(c.datasetID).Table(tableIdentifier) + // only place where we check for parsing errors + datasetTable, err := c.convertToDatasetTable(tableIdentifier) + if err != nil { + return nil, err + } + _, ok := datasetTablesSet[*datasetTable] + if ok { + return nil, fmt.Errorf("invalid mirror: two tables mirror to the same BigQuery table %s", + datasetTable.string()) + } + dataset := c.client.Dataset(datasetTable.dataset) + _, err = dataset.Metadata(c.ctx) + // just assume this means dataset don't exist, and create it + if err != nil { + // if err message does 
not contain `notFound`, then other error happened. + if !strings.Contains(err.Error(), "notFound") { + return nil, fmt.Errorf("error while checking metadata for BigQuery dataset %s: %w", + datasetTable.dataset, err) + } + c.logger.InfoContext(c.ctx, fmt.Sprintf("creating dataset %s...", dataset.DatasetID)) + err = dataset.Create(c.ctx, nil) + if err != nil { + return nil, fmt.Errorf("failed to create BigQuery dataset %s: %w", dataset.DatasetID, err) + } + } + table := dataset.Table(datasetTable.table) // check if the table exists - _, err := table.Metadata(c.ctx) + _, err = table.Metadata(c.ctx) if err == nil { // table exists, go to next table tableExistsMapping[tableIdentifier] = true + datasetTablesSet[*datasetTable] = struct{}{} continue } @@ -999,6 +997,7 @@ func (c *BigQueryConnector) SetupNormalizedTables( } tableExistsMapping[tableIdentifier] = false + datasetTablesSet[*datasetTable] = struct{}{} // log that table was created c.logger.Info(fmt.Sprintf("created table %s", tableIdentifier)) } @@ -1015,10 +1014,6 @@ func (c *BigQueryConnector) SyncFlowCleanup(jobName string) error { if err != nil { return fmt.Errorf("failed to delete raw table: %w", err) } - err = dataset.Table(c.getStagingTableName(jobName)).Delete(c.ctx) - if err != nil { - return fmt.Errorf("failed to delete staging table: %w", err) - } // deleting job from metadata table query := fmt.Sprintf("DELETE FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName) @@ -1036,35 +1031,33 @@ func (c *BigQueryConnector) getRawTableName(flowJobName string) string { return fmt.Sprintf("_peerdb_raw_%s", flowJobName) } -// getStagingTableName returns the staging table name for the given table identifier. -func (c *BigQueryConnector) getStagingTableName(flowJobName string) string { - // replace all non-alphanumeric characters with _ - flowJobName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(flowJobName, "_") - return fmt.Sprintf("_peerdb_staging_%s", flowJobName) -} - func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) { for _, renameRequest := range req.RenameTableOptions { - src := renameRequest.CurrentName - dst := renameRequest.NewName - c.logger.Info(fmt.Sprintf("renaming table '%s' to '%s'...", src, dst)) + srcDatasetTable, _ := c.convertToDatasetTable(renameRequest.CurrentName) + dstDatasetTable, _ := c.convertToDatasetTable(renameRequest.NewName) + c.logger.Info(fmt.Sprintf("renaming table '%s' to '%s'...", srcDatasetTable.string(), + dstDatasetTable.string())) - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("renaming table '%s' to '%s'...", src, dst)) + activity.RecordHeartbeat(c.ctx, fmt.Sprintf("renaming table '%s' to '%s'...", srcDatasetTable.string(), + dstDatasetTable.string())) // drop the dst table if exists - _, err := c.client.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", c.datasetID, dst)).Run(c.ctx) + _, err := c.client.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", + dstDatasetTable.dataset, dstDatasetTable.table)).Run(c.ctx) if err != nil { - return nil, fmt.Errorf("unable to drop table %s: %w", dst, err) + return nil, fmt.Errorf("unable to drop table %s: %w", dstDatasetTable.string(), err) } // rename the src table to dst _, err = c.client.Query(fmt.Sprintf("ALTER TABLE %s.%s RENAME TO %s", - c.datasetID, src, dst)).Run(c.ctx) + srcDatasetTable.dataset, srcDatasetTable.table, dstDatasetTable.table)).Run(c.ctx) if err != nil { - return nil, fmt.Errorf("unable to rename table %s to %s: %w", src, dst, err) + return nil, 
fmt.Errorf("unable to rename table %s to %s: %w", srcDatasetTable.string(), + dstDatasetTable.string(), err) } - c.logger.Info(fmt.Sprintf("successfully renamed table '%s' to '%s'", src, dst)) + c.logger.Info(fmt.Sprintf("successfully renamed table '%s' to '%s'", srcDatasetTable.string(), + dstDatasetTable.string())) } return &protos.RenameTablesOutput{ @@ -1076,13 +1069,15 @@ func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFro *protos.CreateTablesFromExistingOutput, error, ) { for newTable, existingTable := range req.NewToExistingTableMapping { + newDatasetTable, _ := c.convertToDatasetTable(newTable) + existingDatasetTable, _ := c.convertToDatasetTable(existingTable) c.logger.Info(fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable)) activity.RecordHeartbeat(c.ctx, fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable)) // rename the src table to dst - _, err := c.client.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s LIKE %s.%s", - c.datasetID, newTable, c.datasetID, existingTable)).Run(c.ctx) + _, err := c.client.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` LIKE `%s`", + newDatasetTable.string(), existingDatasetTable.string())).Run(c.ctx) if err != nil { return nil, fmt.Errorf("unable to create table %s: %w", newTable, err) } @@ -1094,3 +1089,29 @@ func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFro FlowJobName: req.FlowJobName, }, nil } + +type datasetTable struct { + dataset string + table string +} + +func (d *datasetTable) string() string { + return fmt.Sprintf("%s.%s", d.dataset, d.table) +} + +func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTable, error) { + parts := strings.Split(tableName, ".") + if len(parts) == 1 { + return &datasetTable{ + dataset: c.datasetID, + table: parts[0], + }, nil + } else if len(parts) == 2 { + return &datasetTable{ + dataset: parts[0], + table: parts[1], + }, nil + } else { + return nil, fmt.Errorf("invalid BigQuery table name: %s", tableName) + } +} diff --git a/flow/connectors/bigquery/merge_statement_generator.go b/flow/connectors/bigquery/merge_statement_generator.go index 22f876b8c3..e9a71b06cd 100644 --- a/flow/connectors/bigquery/merge_statement_generator.go +++ b/flow/connectors/bigquery/merge_statement_generator.go @@ -11,20 +11,20 @@ import ( ) type mergeStmtGenerator struct { - // dataset of all the tables - Dataset string - // the table to merge into - NormalizedTable string - // the table where the data is currently staged. - RawTable string + // dataset + raw table + rawDatasetTable *datasetTable + // destination table name, used to retrieve records from raw table + dstTableName string + // dataset + destination table + dstDatasetTable *datasetTable // last synced batchID. - SyncBatchID int64 + syncBatchID int64 // last normalized batchID. - NormalizeBatchID int64 + normalizeBatchID int64 // the schema of the table to merge into - NormalizedTableSchema *protos.TableSchema + normalizedTableSchema *protos.TableSchema // array of toast column combinations that are unchanged - UnchangedToastColumns []string + unchangedToastColumns []string // _PEERDB_IS_DELETED and _SYNCED_AT columns peerdbCols *protos.PeerDBColumns } @@ -34,7 +34,7 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string { // for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR // statement. 
flattenedProjs := make([]string, 0) - for colName, colType := range m.NormalizedTableSchema.Columns { + for colName, colType := range m.normalizedTableSchema.Columns { bqType := qValueKindToBigQueryType(colType) // CAST doesn't work for FLOAT, so rewrite it to FLOAT64. if bqType == bigquery.FloatFieldType { @@ -87,10 +87,10 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string { // normalize anything between last normalized batch id to last sync batchid return fmt.Sprintf(`WITH _peerdb_flattened AS - (SELECT %s FROM %s.%s WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d and + (SELECT %s FROM %s WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d and _peerdb_destination_table_name='%s')`, - strings.Join(flattenedProjs, ", "), m.Dataset, m.RawTable, m.NormalizeBatchID, - m.SyncBatchID, m.NormalizedTable) + strings.Join(flattenedProjs, ", "), m.rawDatasetTable.string(), m.normalizeBatchID, + m.syncBatchID, m.dstTableName) } // generateDeDupedCTE generates a de-duped CTE. @@ -104,7 +104,7 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string { ) _peerdb_ranked WHERE _peerdb_rank = 1 ) SELECT * FROM _peerdb_de_duplicated_data_res` - pkeyColsStr := fmt.Sprintf("(CONCAT(%s))", strings.Join(m.NormalizedTableSchema.PrimaryKeyColumns, + pkeyColsStr := fmt.Sprintf("(CONCAT(%s))", strings.Join(m.normalizedTableSchema.PrimaryKeyColumns, ", '_peerdb_concat_', ")) return fmt.Sprintf(cte, pkeyColsStr) } @@ -112,9 +112,9 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string { // generateMergeStmt generates a merge statement. func (m *mergeStmtGenerator) generateMergeStmt() string { // comma separated list of column names - backtickColNames := make([]string, 0, len(m.NormalizedTableSchema.Columns)) - pureColNames := make([]string, 0, len(m.NormalizedTableSchema.Columns)) - for colName := range m.NormalizedTableSchema.Columns { + backtickColNames := make([]string, 0, len(m.normalizedTableSchema.Columns)) + pureColNames := make([]string, 0, len(m.normalizedTableSchema.Columns)) + for colName := range m.normalizedTableSchema.Columns { backtickColNames = append(backtickColNames, fmt.Sprintf("`%s`", colName)) pureColNames = append(pureColNames, colName) } @@ -123,7 +123,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() string { insertValuesSQL := csep + ",CURRENT_TIMESTAMP" updateStatementsforToastCols := m.generateUpdateStatements(pureColNames, - m.UnchangedToastColumns, m.peerdbCols) + m.unchangedToastColumns, m.peerdbCols) if m.peerdbCols.SoftDelete { softDeleteInsertColumnsSQL := insertColumnsSQL + fmt.Sprintf(", `%s`", m.peerdbCols.SoftDeleteColName) softDeleteInsertValuesSQL := insertValuesSQL + ", TRUE" @@ -134,8 +134,8 @@ func (m *mergeStmtGenerator) generateMergeStmt() string { } updateStringToastCols := strings.Join(updateStatementsforToastCols, " ") - pkeySelectSQLArray := make([]string, 0, len(m.NormalizedTableSchema.PrimaryKeyColumns)) - for _, pkeyColName := range m.NormalizedTableSchema.PrimaryKeyColumns { + pkeySelectSQLArray := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns)) + for _, pkeyColName := range m.normalizedTableSchema.PrimaryKeyColumns { pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("_peerdb_target.%s = _peerdb_deduped.%s", pkeyColName, pkeyColName)) } @@ -153,14 +153,14 @@ func (m *mergeStmtGenerator) generateMergeStmt() string { } return fmt.Sprintf(` - MERGE %s.%s _peerdb_target USING (%s,%s) _peerdb_deduped + MERGE %s _peerdb_target USING (%s,%s) _peerdb_deduped ON %s WHEN NOT MATCHED and 
(_peerdb_deduped._peerdb_record_type != 2) THEN INSERT (%s) VALUES (%s) %s WHEN MATCHED AND (_peerdb_deduped._peerdb_record_type = 2) THEN %s; - `, m.Dataset, m.NormalizedTable, m.generateFlattenedCTE(), m.generateDeDupedCTE(), + `, m.dstDatasetTable.string(), m.generateFlattenedCTE(), m.generateDeDupedCTE(), pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart) } diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go index df771e50a2..305bab01eb 100644 --- a/flow/connectors/bigquery/qrep.go +++ b/flow/connectors/bigquery/qrep.go @@ -45,7 +45,7 @@ func (c *BigQueryConnector) SyncQRepRecords( " partition %s of destination table %s", partition.PartitionId, destTable)) - avroSync := &QRepAvroSyncMethod{connector: c, gcsBucket: config.StagingPath} + avroSync := NewQRepAvroSyncMethod(c, config.StagingPath, config.FlowJobName) return avroSync.SyncQRepRecords(config.FlowJobName, destTable, partition, tblMetadata, stream, config.SyncedAtColName, config.SoftDeleteColName) } @@ -53,11 +53,11 @@ func (c *BigQueryConnector) SyncQRepRecords( func (c *BigQueryConnector) replayTableSchemaDeltasQRep(config *protos.QRepConfig, partition *protos.QRepPartition, srcSchema *model.QRecordSchema, ) (*bigquery.TableMetadata, error) { - destTable := config.DestinationTableIdentifier - bqTable := c.client.Dataset(c.datasetID).Table(destTable) + destDatasetTable, _ := c.convertToDatasetTable(config.DestinationTableIdentifier) + bqTable := c.client.Dataset(destDatasetTable.dataset).Table(destDatasetTable.table) dstTableMetadata, err := bqTable.Metadata(c.ctx) if err != nil { - return nil, fmt.Errorf("failed to get metadata of table %s: %w", destTable, err) + return nil, fmt.Errorf("failed to get metadata of table %s: %w", destDatasetTable, err) } tableSchemaDelta := &protos.TableSchemaDelta{ @@ -92,7 +92,7 @@ func (c *BigQueryConnector) replayTableSchemaDeltasQRep(config *protos.QRepConfi } dstTableMetadata, err = bqTable.Metadata(c.ctx) if err != nil { - return nil, fmt.Errorf("failed to get metadata of table %s: %w", destTable, err) + return nil, fmt.Errorf("failed to get metadata of table %s: %w", destDatasetTable, err) } return dstTableMetadata, nil } diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index 7ed87b0c06..8e600d5279 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -35,7 +35,7 @@ func NewQRepAvroSyncMethod(connector *BigQueryConnector, gcsBucket string, } func (s *QRepAvroSyncMethod) SyncRecords( - dstTableName string, + rawTableName string, flowJobName string, lastCP int64, dstTableMetadata *bigquery.TableMetadata, @@ -45,16 +45,20 @@ func (s *QRepAvroSyncMethod) SyncRecords( activity.RecordHeartbeat(s.connector.ctx, time.Minute, fmt.Sprintf("Flow job %s: Obtaining Avro schema"+ " for destination table %s and sync batch ID %d", - flowJobName, dstTableName, syncBatchID), + flowJobName, rawTableName, syncBatchID), ) // You will need to define your Avro schema as a string - avroSchema, err := DefineAvroSchema(dstTableName, dstTableMetadata, "", "") + avroSchema, err := DefineAvroSchema(rawTableName, dstTableMetadata, "", "") if err != nil { return 0, fmt.Errorf("failed to define Avro schema: %w", err) } - stagingTable := fmt.Sprintf("%s_%s_staging", dstTableName, fmt.Sprint(syncBatchID)) - numRecords, err := s.writeToStage(fmt.Sprint(syncBatchID), dstTableName, avroSchema, stagingTable, stream) + stagingTable := 
fmt.Sprintf("%s_%s_staging", rawTableName, fmt.Sprint(syncBatchID)) + numRecords, err := s.writeToStage(fmt.Sprint(syncBatchID), rawTableName, avroSchema, + &datasetTable{ + dataset: s.connector.datasetID, + table: stagingTable, + }, stream) if err != nil { return -1, fmt.Errorf("failed to push to avro stage: %v", err) } @@ -62,7 +66,7 @@ func (s *QRepAvroSyncMethod) SyncRecords( bqClient := s.connector.client datasetID := s.connector.datasetID insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT * FROM `%s.%s`;", - datasetID, dstTableName, datasetID, stagingTable) + datasetID, rawTableName, datasetID, stagingTable) updateMetadataStmt, err := s.connector.getUpdateMetadataStmt(flowJobName, lastCP, syncBatchID) if err != nil { return -1, fmt.Errorf("failed to update metadata: %v", err) @@ -71,7 +75,7 @@ func (s *QRepAvroSyncMethod) SyncRecords( activity.RecordHeartbeat(s.connector.ctx, time.Minute, fmt.Sprintf("Flow job %s: performing insert and update transaction"+ " for destination table %s and sync batch ID %d", - flowJobName, dstTableName, syncBatchID), + flowJobName, rawTableName, syncBatchID), ) stmts := []string{ @@ -91,12 +95,12 @@ func (s *QRepAvroSyncMethod) SyncRecords( slog.Error("failed to delete staging table "+stagingTable, slog.Any("error", err), slog.String("syncBatchID", fmt.Sprint(syncBatchID)), - slog.String("destinationTable", dstTableName)) + slog.String("destinationTable", rawTableName)) } - slog.Info(fmt.Sprintf("loaded stage into %s.%s", datasetID, dstTableName), + slog.Info(fmt.Sprintf("loaded stage into %s.%s", datasetID, rawTableName), slog.String(string(shared.FlowNameKey), flowJobName), - slog.String("dstTableName", dstTableName)) + slog.String("dstTableName", rawTableName)) return numRecords, nil } @@ -124,8 +128,14 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( slog.Info("Obtained Avro schema for destination table", flowLog) slog.Info(fmt.Sprintf("Avro schema: %v\n", avroSchema), flowLog) // create a staging table name with partitionID replace hyphens with underscores - stagingTable := fmt.Sprintf("%s_%s_staging", dstTableName, strings.ReplaceAll(partition.PartitionId, "-", "_")) - numRecords, err := s.writeToStage(partition.PartitionId, flowJobName, avroSchema, stagingTable, stream) + dstDatasetTable, _ := s.connector.convertToDatasetTable(dstTableName) + stagingDatasetTable := &datasetTable{ + dataset: dstDatasetTable.dataset, + table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table, + strings.ReplaceAll(partition.PartitionId, "-", "_")), + } + numRecords, err := s.writeToStage(partition.PartitionId, flowJobName, avroSchema, + stagingDatasetTable, stream) if err != nil { return -1, fmt.Errorf("failed to push to avro stage: %v", err) } @@ -135,7 +145,6 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( flowJobName, dstTableName, partition.PartitionId), ) bqClient := s.connector.client - datasetID := s.connector.datasetID selector := "*" if softDeleteCol != "" { // PeerDB column @@ -145,8 +154,8 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( selector += ", CURRENT_TIMESTAMP" } // Insert the records from the staging table into the destination table - insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT %s FROM `%s.%s`;", - datasetID, dstTableName, selector, datasetID, stagingTable) + insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT %s FROM `%s`;", + dstDatasetTable.string(), selector, stagingDatasetTable.string()) insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime) if err != nil { @@ -166,14 +175,15 
@@ func (s *QRepAvroSyncMethod) SyncQRepRecords( } // drop the staging table - if err := bqClient.Dataset(datasetID).Table(stagingTable).Delete(s.connector.ctx); err != nil { + if err := bqClient.Dataset(stagingDatasetTable.dataset). + Table(stagingDatasetTable.table).Delete(s.connector.ctx); err != nil { // just log the error this isn't fatal. - slog.Error("failed to delete staging table "+stagingTable, + slog.Error("failed to delete staging table "+stagingDatasetTable.string(), slog.Any("error", err), flowLog) } - slog.Info(fmt.Sprintf("loaded stage into %s.%s", datasetID, dstTableName), flowLog) + slog.Info(fmt.Sprintf("loaded stage into %s", dstDatasetTable.string()), flowLog) return numRecords, nil } @@ -323,7 +333,7 @@ func (s *QRepAvroSyncMethod) writeToStage( syncID string, objectFolder string, avroSchema *model.QRecordAvroSchemaDefinition, - stagingTable string, + stagingTable *datasetTable, stream *model.QRecordStream, ) (int, error) { shutdown := utils.HeartbeatRoutine(s.connector.ctx, time.Minute, @@ -379,7 +389,6 @@ func (s *QRepAvroSyncMethod) writeToStage( slog.Info(fmt.Sprintf("wrote %d records", avroFile.NumRecords), idLog) bqClient := s.connector.client - datasetID := s.connector.datasetID var avroRef bigquery.LoadSource if s.gcsBucket != "" { gcsRef := bigquery.NewGCSReference(fmt.Sprintf("gs://%s/%s", s.gcsBucket, avroFile.FilePath)) @@ -396,7 +405,7 @@ func (s *QRepAvroSyncMethod) writeToStage( avroRef = localRef } - loader := bqClient.Dataset(datasetID).Table(stagingTable).LoaderFrom(avroRef) + loader := bqClient.Dataset(stagingTable.dataset).Table(stagingTable.table).LoaderFrom(avroRef) loader.UseAvroLogicalTypes = true loader.WriteDisposition = bigquery.WriteTruncate job, err := loader.Run(s.connector.ctx) @@ -412,7 +421,7 @@ func (s *QRepAvroSyncMethod) writeToStage( if err := status.Err(); err != nil { return 0, fmt.Errorf("failed to load Avro file into BigQuery table: %w", err) } - slog.Info(fmt.Sprintf("Pushed into %s/%s", avroFile.FilePath, syncID)) + slog.Info(fmt.Sprintf("Pushed into %s", avroFile.FilePath)) err = s.connector.waitForTableReady(stagingTable) if err != nil { diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 05347a4263..c8ba3dad41 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -164,7 +164,7 @@ func (c *EventHubConnector) processBatch( return 0, err } - topicName, err := NewScopedEventhub(record.GetTableName()) + topicName, err := NewScopedEventhub(record.GetDestinationTableName()) if err != nil { c.logger.Error("failed to get topic name", slog.Any("error", err)) return 0, err diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 2be3fcb2a5..4c5693f292 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -365,7 +365,7 @@ func (p *PostgresCDCSource) consumeStream( } if rec != nil { - tableName := rec.GetTableName() + tableName := rec.GetDestinationTableName() switch r := rec.(type) { case *model.UpdateRecord: // tableName here is destination tableName. 
@@ -843,7 +843,7 @@ func (p *PostgresCDCSource) processRelationMessage( func (p *PostgresCDCSource) recToTablePKey(req *model.PullRecordsRequest, rec model.Record, ) (*model.TableWithPkey, error) { - tableName := rec.GetTableName() + tableName := rec.GetDestinationTableName() pkeyColsMerged := make([]byte, 0) for _, pkeyCol := range req.TableNameSchemaMapping[tableName].PrimaryKeyColumns { diff --git a/flow/e2e/bigquery/bigquery_helper.go b/flow/e2e/bigquery/bigquery_helper.go index fb9dadb9ba..21bd3b5c75 100644 --- a/flow/e2e/bigquery/bigquery_helper.go +++ b/flow/e2e/bigquery/bigquery_helper.go @@ -94,12 +94,11 @@ func generateBQPeer(bigQueryConfig *protos.BigqueryConfig) *protos.Peer { } // datasetExists checks if the dataset exists. -func (b *BigQueryTestHelper) datasetExists() (bool, error) { - dataset := b.client.Dataset(b.Config.DatasetId) +func (b *BigQueryTestHelper) datasetExists(datasetName string) (bool, error) { + dataset := b.client.Dataset(datasetName) meta, err := dataset.Metadata(context.Background()) if err != nil { // if err message contains `notFound` then dataset does not exist. - // first we cast the error to a bigquery.Error if strings.Contains(err.Error(), "notFound") { fmt.Printf("dataset %s does not exist\n", b.Config.DatasetId) return false, nil @@ -117,12 +116,12 @@ func (b *BigQueryTestHelper) datasetExists() (bool, error) { // RecreateDataset recreates the dataset, i.e, deletes it if exists and creates it again. func (b *BigQueryTestHelper) RecreateDataset() error { - exists, err := b.datasetExists() + exists, err := b.datasetExists(b.datasetName) if err != nil { return fmt.Errorf("failed to check if dataset %s exists: %w", b.Config.DatasetId, err) } - dataset := b.client.Dataset(b.Config.DatasetId) + dataset := b.client.Dataset(b.datasetName) if exists { err := dataset.DeleteWithContents(context.Background()) if err != nil { @@ -135,13 +134,13 @@ func (b *BigQueryTestHelper) RecreateDataset() error { return fmt.Errorf("failed to create dataset: %w", err) } - fmt.Printf("created dataset %s successfully\n", b.Config.DatasetId) + fmt.Printf("created dataset %s successfully\n", b.datasetName) return nil } // DropDataset drops the dataset. -func (b *BigQueryTestHelper) DropDataset() error { - exists, err := b.datasetExists() +func (b *BigQueryTestHelper) DropDataset(datasetName string) error { + exists, err := b.datasetExists(datasetName) if err != nil { return fmt.Errorf("failed to check if dataset %s exists: %w", b.Config.DatasetId, err) } @@ -150,7 +149,7 @@ func (b *BigQueryTestHelper) DropDataset() error { return nil } - dataset := b.client.Dataset(b.Config.DatasetId) + dataset := b.client.Dataset(datasetName) err = dataset.DeleteWithContents(context.Background()) if err != nil { return fmt.Errorf("failed to delete dataset: %w", err) @@ -171,7 +170,11 @@ func (b *BigQueryTestHelper) RunCommand(command string) error { // countRows(tableName) returns the number of rows in the given table. 
func (b *BigQueryTestHelper) countRows(tableName string) (int, error) { - command := fmt.Sprintf("SELECT COUNT(*) FROM `%s.%s`", b.Config.DatasetId, tableName) + return b.countRowsWithDataset(b.datasetName, tableName) +} + +func (b *BigQueryTestHelper) countRowsWithDataset(dataset, tableName string) (int, error) { + command := fmt.Sprintf("SELECT COUNT(*) FROM `%s.%s`", dataset, tableName) it, err := b.client.Query(command).Read(context.Background()) if err != nil { return 0, fmt.Errorf("failed to run command: %w", err) diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go index b28577f4d3..c76688f79b 100644 --- a/flow/e2e/bigquery/peer_flow_bq_test.go +++ b/flow/e2e/bigquery/peer_flow_bq_test.go @@ -150,7 +150,7 @@ func (s PeerFlowE2ETestSuiteBQ) tearDownSuite() { s.FailNow() } - err = s.bqHelper.DropDataset() + err = s.bqHelper.DropDataset(s.bqHelper.datasetName) if err != nil { slog.Error("failed to tear down bigquery", slog.Any("error", err)) s.FailNow() @@ -1203,3 +1203,71 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Columns_BQ() { env.AssertExpectations(s.t) } + +func (s PeerFlowE2ETestSuiteBQ) Test_Multi_Table_Multi_Dataset_BQ() { + env := e2e.NewTemporalTestWorkflowEnvironment() + e2e.RegisterWorkflowsAndActivities(env, s.t) + + srcTable1Name := s.attachSchemaSuffix("test1_bq") + dstTable1Name := "test1_bq" + secondDataset := fmt.Sprintf("%s_2", s.bqHelper.datasetName) + srcTable2Name := s.attachSchemaSuffix("test2_bq") + dstTable2Name := "test2_bq" + + _, err := s.pool.Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE %s(id serial primary key, c1 int, c2 text); + CREATE TABLE %s(id serial primary key, c1 int, c2 text); + `, srcTable1Name, srcTable2Name)) + require.NoError(s.t, err) + + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: s.attachSuffix("test_multi_table_multi_dataset_bq"), + TableNameMapping: map[string]string{ + srcTable1Name: dstTable1Name, + srcTable2Name: fmt.Sprintf("%s.%s", secondDataset, dstTable2Name), + }, + PostgresPort: e2e.PostgresPort, + Destination: s.bqHelper.Peer, + CdcStagingPath: "", + } + + flowConnConfig, err := connectionGen.GenerateFlowConnectionConfigs() + require.NoError(s.t, err) + + limits := peerflow.CDCFlowLimits{ + ExitAfterRecords: 2, + MaxBatchSize: 100, + } + + // in a separate goroutine, wait for PeerFlowStatusQuery to finish setup + // and execute a transaction touching toast columns + go func() { + e2e.SetupCDCFlowStatusQuery(env, connectionGen) + /* inserting across multiple tables*/ + _, err = s.pool.Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s (c1,c2) VALUES (1,'dummy_1'); + INSERT INTO %s (c1,c2) VALUES (-1,'dummy_-1'); + `, srcTable1Name, srcTable2Name)) + require.NoError(s.t, err) + fmt.Println("Executed an insert on two tables") + }() + + env.ExecuteWorkflow(peerflow.CDCFlowWorkflowWithConfig, flowConnConfig, &limits, nil) + + // Verify workflow completes without error + require.True(s.t, env.IsWorkflowCompleted()) + err = env.GetWorkflowError() + + count1, err := s.bqHelper.countRows(dstTable1Name) + require.NoError(s.t, err) + count2, err := s.bqHelper.countRowsWithDataset(secondDataset, dstTable2Name) + require.NoError(s.t, err) + + s.Equal(1, count1) + s.Equal(1, count2) + + err = s.bqHelper.DropDataset(secondDataset) + require.NoError(s.t, err) + + env.AssertExpectations(s.t) +} diff --git a/flow/model/model.go b/flow/model/model.go index 581b57178b..fc2c12d849 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -58,7 +58,7 @@ type Record 
interface { // GetCheckPointID returns the ID of the record. GetCheckPointID() int64 // get table name - GetTableName() string + GetDestinationTableName() string // get columns and values for the record GetItems() *RecordItems } @@ -244,7 +244,7 @@ func (r *InsertRecord) GetCheckPointID() int64 { return r.CheckPointID } -func (r *InsertRecord) GetTableName() string { +func (r *InsertRecord) GetDestinationTableName() string { return r.DestinationTableName } @@ -273,7 +273,7 @@ func (r *UpdateRecord) GetCheckPointID() int64 { } // Implement Record interface for UpdateRecord. -func (r *UpdateRecord) GetTableName() string { +func (r *UpdateRecord) GetDestinationTableName() string { return r.DestinationTableName } @@ -299,7 +299,7 @@ func (r *DeleteRecord) GetCheckPointID() int64 { return r.CheckPointID } -func (r *DeleteRecord) GetTableName() string { +func (r *DeleteRecord) GetDestinationTableName() string { return r.DestinationTableName } @@ -470,8 +470,8 @@ func (r *RelationRecord) GetCheckPointID() int64 { return r.CheckPointID } -func (r *RelationRecord) GetTableName() string { - return r.TableSchemaDelta.SrcTableName +func (r *RelationRecord) GetDestinationTableName() string { + return r.TableSchemaDelta.DstTableName } func (r *RelationRecord) GetItems() *RecordItems {
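
Note on the second patch: below is a minimal, self-contained Go sketch of the `<dataset>.<table>` resolution rule it introduces, adapted from the `convertToDatasetTable` method added to flow/connectors/bigquery/bigquery.go. The standalone `convertToDatasetTable` function, its `defaultDataset` parameter, and the sample names in `main` are illustrative stand-ins for the connector method, its configured `datasetID`, and real table mappings; they are not part of the patch itself.

package main

import (
	"fmt"
	"strings"
)

// datasetTable mirrors the struct added in bigquery.go: a table name plus the
// BigQuery dataset it lives in.
type datasetTable struct {
	dataset string
	table   string
}

func (d datasetTable) String() string {
	return fmt.Sprintf("%s.%s", d.dataset, d.table)
}

// convertToDatasetTable follows the parsing rule from the patch: an unqualified
// name resolves against the peer's default dataset, "dataset.table" is used
// as-is, and anything with more than one dot is rejected.
func convertToDatasetTable(defaultDataset, tableName string) (datasetTable, error) {
	parts := strings.Split(tableName, ".")
	switch len(parts) {
	case 1:
		return datasetTable{dataset: defaultDataset, table: parts[0]}, nil
	case 2:
		return datasetTable{dataset: parts[0], table: parts[1]}, nil
	default:
		return datasetTable{}, fmt.Errorf("invalid BigQuery table name: %s", tableName)
	}
}

func main() {
	// "peerdb_dataset" stands in for the dataset configured on the BigQuery peer.
	for _, name := range []string{"test1_bq", "second_dataset.test2_bq", "a.b.c"} {
		dt, err := convertToDatasetTable("peerdb_dataset", name)
		if err != nil {
			fmt.Println("error:", err)
			continue
		}
		fmt.Printf("%s -> %s\n", name, dt)
	}
}

Because every destination name is normalized to a (dataset, table) pair this way, `SetupNormalizedTables` in the patch can also detect when two source tables map to the same BigQuery table written in two different formats.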