diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 0a220ef424..3da34f99d7 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -29,9 +29,7 @@ import (
 const (
 	/*
 	Different batch Ids in code/BigQuery
-	1. batchID - identifier in raw/staging tables on target to depict which batch a row was inserted.
-	2. stagingBatchID - the random batch id we generate before ingesting into staging table.
-	helps filter rows in the current batch before inserting into raw table.
-	3. syncBatchID - batch id that was last synced or will be synced
-	4. normalizeBatchID - batch id that was last normalized or will be normalized.
+	1. batchID - identifier in the raw table on the target, indicating the batch a row was inserted in.
+	2. syncBatchID - batch id that was last synced or will be synced.
+	3. normalizeBatchID - batch id that was last normalized or will be normalized.
 	*/
@@ -233,8 +231,8 @@ func (c *BigQueryConnector) InitializeTableSchema(req map[string]*protos.TableSc
 	return nil
 }
 
-func (c *BigQueryConnector) waitForTableReady(tblName string) error {
-	table := c.client.Dataset(c.datasetID).Table(tblName)
+func (c *BigQueryConnector) waitForTableReady(datasetTable *datasetTable) error {
+	table := c.client.Dataset(datasetTable.dataset).Table(datasetTable.table)
 	maxDuration := 5 * time.Minute
 	deadline := time.Now().Add(maxDuration)
 	sleepInterval := 5 * time.Second
@@ -242,7 +240,7 @@ func (c *BigQueryConnector) waitForTableReady(tblName string) error {
 
 	for {
 		if time.Now().After(deadline) {
-			return fmt.Errorf("timeout reached while waiting for table %s to be ready", tblName)
+			return fmt.Errorf("timeout reached while waiting for table %s to be ready", datasetTable.string())
 		}
 
 		_, err := table.Metadata(c.ctx)
@@ -250,7 +248,8 @@ func (c *BigQueryConnector) waitForTableReady(tblName string) error {
 			return nil
 		}
 
-		slog.Info("waiting for table to be ready", slog.String("table", tblName), slog.Int("attempt", attempt))
+		slog.Info("waiting for table to be ready",
+			slog.String("table", datasetTable.table), slog.Int("attempt", attempt))
 		attempt++
 		time.Sleep(sleepInterval)
 	}
@@ -267,9 +266,10 @@ func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string,
 		}
 
 		for _, addedColumn := range schemaDelta.AddedColumns {
+			dstDatasetTable, _ := c.convertToDatasetTable(schemaDelta.DstTableName)
 			_, err := c.client.Query(fmt.Sprintf(
-				"ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s", c.datasetID,
-				schemaDelta.DstTableName, addedColumn.ColumnName,
+				"ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s", dstDatasetTable.dataset,
+				dstDatasetTable.table, addedColumn.ColumnName,
 				qValueKindToBigQueryType(addedColumn.ColumnType))).Read(c.ctx)
 			if err != nil {
 				return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
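Context for the hunk above: ReplayTableSchemaDeltas now targets whichever dataset the destination identifier parses to, instead of always using the connector's default dataset. A minimal sketch of the statement this builds; the dataset, table, and column values are placeholders, and `datasetTable` mirrors the struct introduced at the bottom of this file:

```go
package main

import "fmt"

type datasetTable struct {
	dataset string
	table   string
}

func main() {
	// hypothetical schema delta: column "city" of BigQuery type STRING,
	// added to a destination that parsed to an explicit dataset
	dst := datasetTable{dataset: "analytics", table: "users"}
	stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s",
		dst.dataset, dst.table, "city", "STRING")
	fmt.Println(stmt)
	// ALTER TABLE analytics.users ADD COLUMN IF NOT EXISTS `city` STRING
}
```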
@@ -593,16 +593,11 @@ func (c *BigQueryConnector) syncRecordsViaAvro(
 		var entries [10]qvalue.QValue
 		switch r := record.(type) {
 		case *model.InsertRecord:
-
 			itemsJSON, err := r.Items.ToJSON()
 			if err != nil {
 				return nil, fmt.Errorf("failed to create items to json: %v", err)
 			}
 
-			entries[3] = qvalue.QValue{
-				Kind:  qvalue.QValueKindString,
-				Value: r.DestinationTableName,
-			}
 			entries[4] = qvalue.QValue{
 				Kind:  qvalue.QValueKindString,
 				Value: itemsJSON,
@@ -626,16 +621,11 @@ func (c *BigQueryConnector) syncRecordsViaAvro(
 			if err != nil {
 				return nil, fmt.Errorf("failed to create new items to json: %v", err)
 			}
-
 			oldItemsJSON, err := r.OldItems.ToJSON()
 			if err != nil {
 				return nil, fmt.Errorf("failed to create old items to json: %v", err)
 			}
 
-			entries[3] = qvalue.QValue{
-				Kind:  qvalue.QValueKindString,
-				Value: r.DestinationTableName,
-			}
 			entries[4] = qvalue.QValue{
 				Kind:  qvalue.QValueKindString,
 				Value: newItemsJSON,
@@ -660,10 +650,6 @@ func (c *BigQueryConnector) syncRecordsViaAvro(
 				return nil, fmt.Errorf("failed to create items to json: %v", err)
 			}
 
-			entries[3] = qvalue.QValue{
-				Kind:  qvalue.QValueKindString,
-				Value: r.DestinationTableName,
-			}
 			entries[4] = qvalue.QValue{
 				Kind:  qvalue.QValueKindString,
 				Value: itemsJSON,
@@ -698,6 +684,10 @@ func (c *BigQueryConnector) syncRecordsViaAvro(
 			Kind:  qvalue.QValueKindInt64,
 			Value: time.Now().UnixNano(),
 		}
+		entries[3] = qvalue.QValue{
+			Kind:  qvalue.QValueKindString,
+			Value: record.GetDestinationTableName(),
+		}
 		entries[7] = qvalue.QValue{
 			Kind:  qvalue.QValueKindInt64,
 			Value: syncBatchID,
@@ -787,14 +777,18 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
 		c.datasetID, rawTableName, distinctTableNames))
 
 	for _, tableName := range distinctTableNames {
+		dstDatasetTable, _ := c.convertToDatasetTable(tableName)
 		mergeGen := &mergeStmtGenerator{
-			Dataset:               c.datasetID,
-			NormalizedTable:       tableName,
-			RawTable:              rawTableName,
-			NormalizedTableSchema: c.tableNameSchemaMapping[tableName],
-			SyncBatchID:           syncBatchID,
-			NormalizeBatchID:      normalizeBatchID,
-			UnchangedToastColumns: tableNametoUnchangedToastCols[tableName],
+			rawDatasetTable: &datasetTable{
+				dataset: c.datasetID,
+				table:   rawTableName,
+			},
+			dstTableName:          tableName,
+			dstDatasetTable:       dstDatasetTable,
+			normalizedTableSchema: c.tableNameSchemaMapping[tableName],
+			syncBatchID:           syncBatchID,
+			normalizeBatchID:      normalizeBatchID,
+			unchangedToastColumns: tableNametoUnchangedToastCols[tableName],
 			peerdbCols: &protos.PeerDBColumns{
 				SoftDeleteColName: req.SoftDeleteColName,
 				SyncedAtColName:   req.SyncedAtColName,
@@ -846,19 +840,6 @@ func (c *BigQueryConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr
 		{Name: "_peerdb_unchanged_toast_columns", Type: bigquery.StringFieldType},
 	}
 
-	stagingSchema := bigquery.Schema{
-		{Name: "_peerdb_uid", Type: bigquery.StringFieldType},
-		{Name: "_peerdb_timestamp", Type: bigquery.TimestampFieldType},
-		{Name: "_peerdb_timestamp_nanos", Type: bigquery.IntegerFieldType},
-		{Name: "_peerdb_destination_table_name", Type: bigquery.StringFieldType},
-		{Name: "_peerdb_data", Type: bigquery.StringFieldType},
-		{Name: "_peerdb_record_type", Type: bigquery.IntegerFieldType},
-		{Name: "_peerdb_match_data", Type: bigquery.StringFieldType},
-		{Name: "_peerdb_batch_id", Type: bigquery.IntegerFieldType},
-		{Name: "_peerdb_staging_batch_id", Type: bigquery.IntegerFieldType},
-		{Name: "_peerdb_unchanged_toast_columns", Type: bigquery.StringFieldType},
-	}
-
 	// create the table
 	table := c.client.Dataset(c.datasetID).Table(rawTableName)
 
@@ -883,16 +864,6 @@ func (c *BigQueryConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr
 		return nil, fmt.Errorf("failed to create table %s.%s: %w", c.datasetID, rawTableName, err)
 	}
 
-	// also create a staging table for this raw table
-	stagingTableName := c.getStagingTableName(req.FlowJobName)
-	stagingTable := c.client.Dataset(c.datasetID).Table(stagingTableName)
-	err = stagingTable.Create(c.ctx, &bigquery.TableMetadata{
-		Schema: stagingSchema,
-	})
-	if err != nil {
-		return nil, fmt.Errorf("failed to create table %s.%s: %w", c.datasetID, stagingTableName, err)
-	}
-
 	return &protos.CreateRawTableOutput{
 		TableIdentifier: rawTableName,
 	}, nil
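The hunks above consolidate the `_peerdb_destination_table_name` assignment (`entries[3]`) into the shared tail of the switch, which works now that every record type answers `GetDestinationTableName()`. For orientation, a sketch of how the `entries` slots appear to line up with the raw-table columns; indices 3, 4, and 7 are confirmed by the hunks, the rest are inferred from the raw schema above and should be treated as an assumption:

```go
package main

import "fmt"

// Assumed mapping of the entries array onto the raw-table columns listed in
// CreateRawTable; only indices 3, 4, and 7 are confirmed by this diff.
var rawColumnForEntry = [...]string{
	0: "_peerdb_uid",
	1: "_peerdb_timestamp",
	2: "_peerdb_timestamp_nanos",
	3: "_peerdb_destination_table_name", // now set once for every record type
	4: "_peerdb_data",
	5: "_peerdb_record_type",
	6: "_peerdb_match_data",
	7: "_peerdb_batch_id",
	8: "_peerdb_unchanged_toast_columns",
}

func main() {
	for i, col := range rawColumnForEntry {
		fmt.Printf("entries[%d] -> %s\n", i, col)
	}
}
```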
@@ -952,14 +923,41 @@ func (c *BigQueryConnector) SetupNormalizedTables(
 	req *protos.SetupNormalizedTableBatchInput,
 ) (*protos.SetupNormalizedTableBatchOutput, error) {
 	tableExistsMapping := make(map[string]bool)
+	datasetTablesSet := make(map[datasetTable]struct{})
 	for tableIdentifier, tableSchema := range req.TableNameSchemaMapping {
-		table := c.client.Dataset(c.datasetID).Table(tableIdentifier)
+		// only place where we check for parsing errors
+		datasetTable, err := c.convertToDatasetTable(tableIdentifier)
+		if err != nil {
+			return nil, err
+		}
+		_, ok := datasetTablesSet[*datasetTable]
+		if ok {
+			return nil, fmt.Errorf("invalid mirror: two tables mirror to the same BigQuery table %s",
+				datasetTable.string())
+		}
+		dataset := c.client.Dataset(datasetTable.dataset)
+		_, err = dataset.Metadata(c.ctx)
+		// just assume an error here means the dataset doesn't exist, and create it
+		if err != nil {
+			// if the error message does not contain `notFound`, some other error happened.
+			if !strings.Contains(err.Error(), "notFound") {
+				return nil, fmt.Errorf("error while checking metadata for BigQuery dataset %s: %w",
+					datasetTable.dataset, err)
+			}
+			c.logger.InfoContext(c.ctx, fmt.Sprintf("creating dataset %s...", dataset.DatasetID))
+			err = dataset.Create(c.ctx, nil)
+			if err != nil {
+				return nil, fmt.Errorf("failed to create BigQuery dataset %s: %w", dataset.DatasetID, err)
+			}
+		}
+		table := dataset.Table(datasetTable.table)
 
 		// check if the table exists
-		_, err := table.Metadata(c.ctx)
+		_, err = table.Metadata(c.ctx)
 		if err == nil {
 			// table exists, go to next table
 			tableExistsMapping[tableIdentifier] = true
+			datasetTablesSet[*datasetTable] = struct{}{}
 			continue
 		}
@@ -999,6 +997,7 @@ func (c *BigQueryConnector) SetupNormalizedTables(
 		}
 
 		tableExistsMapping[tableIdentifier] = false
+		datasetTablesSet[*datasetTable] = struct{}{}
 		// log that table was created
 		c.logger.Info(fmt.Sprintf("created table %s", tableIdentifier))
 	}
@@ -1015,10 +1014,6 @@ func (c *BigQueryConnector) SyncFlowCleanup(jobName string) error {
 	if err != nil {
 		return fmt.Errorf("failed to delete raw table: %w", err)
 	}
-	err = dataset.Table(c.getStagingTableName(jobName)).Delete(c.ctx)
-	if err != nil {
-		return fmt.Errorf("failed to delete staging table: %w", err)
-	}
 
 	// deleting job from metadata table
 	query := fmt.Sprintf("DELETE FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName)
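The dataset check-then-create added above follows a common pattern with the cloud.google.com/go/bigquery client: probe `Dataset.Metadata`, treat a `notFound` error as absence, and create the dataset, propagating any other error. A standalone sketch of that pattern, with placeholder project and dataset names:

```go
package main

import (
	"context"
	"log"
	"strings"

	"cloud.google.com/go/bigquery"
)

// ensureDataset mirrors the logic in SetupNormalizedTables: a `notFound`
// metadata error means the dataset is missing, so create it; anything else
// is a real failure.
func ensureDataset(ctx context.Context, client *bigquery.Client, name string) error {
	ds := client.Dataset(name)
	if _, err := ds.Metadata(ctx); err != nil {
		if !strings.Contains(err.Error(), "notFound") {
			return err
		}
		return ds.Create(ctx, nil)
	}
	return nil
}

func main() {
	ctx := context.Background()
	// "my-project" and "analytics" are placeholder values
	client, err := bigquery.NewClient(ctx, "my-project")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	if err := ensureDataset(ctx, client, "analytics"); err != nil {
		log.Fatal(err)
	}
}
```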
@@ -1036,35 +1031,33 @@ func (c *BigQueryConnector) getRawTableName(flowJobName string) string {
 	return fmt.Sprintf("_peerdb_raw_%s", flowJobName)
 }
 
-// getStagingTableName returns the staging table name for the given table identifier.
-func (c *BigQueryConnector) getStagingTableName(flowJobName string) string {
-	// replace all non-alphanumeric characters with _
-	flowJobName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(flowJobName, "_")
-	return fmt.Sprintf("_peerdb_staging_%s", flowJobName)
-}
-
 func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) {
 	for _, renameRequest := range req.RenameTableOptions {
-		src := renameRequest.CurrentName
-		dst := renameRequest.NewName
-		c.logger.Info(fmt.Sprintf("renaming table '%s' to '%s'...", src, dst))
+		srcDatasetTable, _ := c.convertToDatasetTable(renameRequest.CurrentName)
+		dstDatasetTable, _ := c.convertToDatasetTable(renameRequest.NewName)
+		c.logger.Info(fmt.Sprintf("renaming table '%s' to '%s'...", srcDatasetTable.string(),
+			dstDatasetTable.string()))
 
-		activity.RecordHeartbeat(c.ctx, fmt.Sprintf("renaming table '%s' to '%s'...", src, dst))
+		activity.RecordHeartbeat(c.ctx, fmt.Sprintf("renaming table '%s' to '%s'...", srcDatasetTable.string(),
+			dstDatasetTable.string()))
 
 		// drop the dst table if exists
-		_, err := c.client.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", c.datasetID, dst)).Run(c.ctx)
+		_, err := c.client.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s",
+			dstDatasetTable.dataset, dstDatasetTable.table)).Run(c.ctx)
 		if err != nil {
-			return nil, fmt.Errorf("unable to drop table %s: %w", dst, err)
+			return nil, fmt.Errorf("unable to drop table %s: %w", dstDatasetTable.string(), err)
 		}
 
 		// rename the src table to dst
 		_, err = c.client.Query(fmt.Sprintf("ALTER TABLE %s.%s RENAME TO %s",
-			c.datasetID, src, dst)).Run(c.ctx)
+			srcDatasetTable.dataset, srcDatasetTable.table, dstDatasetTable.table)).Run(c.ctx)
 		if err != nil {
-			return nil, fmt.Errorf("unable to rename table %s to %s: %w", src, dst, err)
+			return nil, fmt.Errorf("unable to rename table %s to %s: %w", srcDatasetTable.string(),
+				dstDatasetTable.string(), err)
 		}
 
-		c.logger.Info(fmt.Sprintf("successfully renamed table '%s' to '%s'", src, dst))
+		c.logger.Info(fmt.Sprintf("successfully renamed table '%s' to '%s'", srcDatasetTable.string(),
+			dstDatasetTable.string()))
 	}
 
 	return &protos.RenameTablesOutput{
@@ -1076,13 +1069,15 @@ func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFro
 	*protos.CreateTablesFromExistingOutput, error,
 ) {
 	for newTable, existingTable := range req.NewToExistingTableMapping {
+		newDatasetTable, _ := c.convertToDatasetTable(newTable)
+		existingDatasetTable, _ := c.convertToDatasetTable(existingTable)
 		c.logger.Info(fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable))
 
 		activity.RecordHeartbeat(c.ctx, fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable))
 
-		// rename the src table to dst
-		_, err := c.client.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s LIKE %s.%s",
-			c.datasetID, newTable, c.datasetID, existingTable)).Run(c.ctx)
+		// create the new table with the same schema as the existing table
+		_, err := c.client.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` LIKE `%s`",
+			newDatasetTable.string(), existingDatasetTable.string())).Run(c.ctx)
 		if err != nil {
 			return nil, fmt.Errorf("unable to create table %s: %w", newTable, err)
 		}
@@ -1094,3 +1089,29 @@ func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFro
 		FlowJobName: req.FlowJobName,
 	}, nil
 }
+
+type datasetTable struct {
+	dataset string
+	table   string
+}
+
+func (d *datasetTable) string() string {
+	return fmt.Sprintf("%s.%s", d.dataset, d.table)
+}
+
+func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTable, error) {
+	parts := strings.Split(tableName, ".")
+	if len(parts) == 1 {
+		return &datasetTable{
+			dataset: c.datasetID,
+			table:   parts[0],
+		}, nil
+	} else if len(parts) == 2 {
+		return &datasetTable{
+			dataset: parts[0],
+			table:   parts[1],
+		}, nil
+	} else {
+		return nil, fmt.Errorf("invalid BigQuery table name: %s", tableName)
+	}
+}
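A quick illustration of the parsing rule `convertToDatasetTable` introduces. This is a standalone re-implementation for demonstration only, with "peerdb_dataset" standing in for the connector's configured default dataset:

```go
package main

import (
	"fmt"
	"strings"
)

type datasetTable struct {
	dataset string
	table   string
}

// convertToDatasetTable re-states the new rule: a bare table name inherits
// the connector's default dataset, "dataset.table" selects an explicit
// dataset, and anything with more dots is rejected.
func convertToDatasetTable(defaultDataset, tableName string) (*datasetTable, error) {
	parts := strings.Split(tableName, ".")
	switch len(parts) {
	case 1:
		return &datasetTable{dataset: defaultDataset, table: parts[0]}, nil
	case 2:
		return &datasetTable{dataset: parts[0], table: parts[1]}, nil
	default:
		return nil, fmt.Errorf("invalid BigQuery table name: %s", tableName)
	}
}

func main() {
	for _, name := range []string{"events", "analytics.events", "a.b.c"} {
		dt, err := convertToDatasetTable("peerdb_dataset", name)
		if err != nil {
			fmt.Printf("%s -> error: %v\n", name, err)
			continue
		}
		fmt.Printf("%s -> %s.%s\n", name, dt.dataset, dt.table)
	}
}
```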
diff --git a/flow/connectors/bigquery/merge_statement_generator.go b/flow/connectors/bigquery/merge_statement_generator.go
index 22f876b8c3..e9a71b06cd 100644
--- a/flow/connectors/bigquery/merge_statement_generator.go
+++ b/flow/connectors/bigquery/merge_statement_generator.go
@@ -11,20 +11,20 @@ import (
 )
 
 type mergeStmtGenerator struct {
-	// dataset of all the tables
-	Dataset string
-	// the table to merge into
-	NormalizedTable string
-	// the table where the data is currently staged.
-	RawTable string
+	// dataset + raw table
+	rawDatasetTable *datasetTable
+	// destination table name, used to retrieve records from the raw table
+	dstTableName string
+	// dataset + destination table
+	dstDatasetTable *datasetTable
 	// last synced batchID.
-	SyncBatchID int64
+	syncBatchID int64
 	// last normalized batchID.
-	NormalizeBatchID int64
+	normalizeBatchID int64
 	// the schema of the table to merge into
-	NormalizedTableSchema *protos.TableSchema
+	normalizedTableSchema *protos.TableSchema
 	// array of toast column combinations that are unchanged
-	UnchangedToastColumns []string
+	unchangedToastColumns []string
 	// _PEERDB_IS_DELETED and _SYNCED_AT columns
 	peerdbCols *protos.PeerDBColumns
 }
@@ -34,7 +34,7 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string {
 	// for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
 	// statement.
 	flattenedProjs := make([]string, 0)
-	for colName, colType := range m.NormalizedTableSchema.Columns {
+	for colName, colType := range m.normalizedTableSchema.Columns {
 		bqType := qValueKindToBigQueryType(colType)
 		// CAST doesn't work for FLOAT, so rewrite it to FLOAT64.
 		if bqType == bigquery.FloatFieldType {
@@ -87,10 +87,10 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string {
 	// normalize anything between last normalized batch id to last sync batchid
 	return fmt.Sprintf(`WITH _peerdb_flattened AS
-	(SELECT %s FROM %s.%s WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d and
+	(SELECT %s FROM %s WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d and
 	_peerdb_destination_table_name='%s')`,
-		strings.Join(flattenedProjs, ", "), m.Dataset, m.RawTable, m.NormalizeBatchID,
-		m.SyncBatchID, m.NormalizedTable)
+		strings.Join(flattenedProjs, ", "), m.rawDatasetTable.string(), m.normalizeBatchID,
+		m.syncBatchID, m.dstTableName)
 }
 
 // generateDeDupedCTE generates a de-duped CTE.
@@ -104,7 +104,7 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
 		) _peerdb_ranked
 		WHERE _peerdb_rank = 1
 	) SELECT * FROM _peerdb_de_duplicated_data_res`
-	pkeyColsStr := fmt.Sprintf("(CONCAT(%s))", strings.Join(m.NormalizedTableSchema.PrimaryKeyColumns,
+	pkeyColsStr := fmt.Sprintf("(CONCAT(%s))", strings.Join(m.normalizedTableSchema.PrimaryKeyColumns,
 		", '_peerdb_concat_', "))
 	return fmt.Sprintf(cte, pkeyColsStr)
 }
 
 // generateMergeStmt generates a merge statement.
 func (m *mergeStmtGenerator) generateMergeStmt() string {
 	// comma separated list of column names
-	backtickColNames := make([]string, 0, len(m.NormalizedTableSchema.Columns))
-	pureColNames := make([]string, 0, len(m.NormalizedTableSchema.Columns))
-	for colName := range m.NormalizedTableSchema.Columns {
+	backtickColNames := make([]string, 0, len(m.normalizedTableSchema.Columns))
+	pureColNames := make([]string, 0, len(m.normalizedTableSchema.Columns))
+	for colName := range m.normalizedTableSchema.Columns {
 		backtickColNames = append(backtickColNames, fmt.Sprintf("`%s`", colName))
 		pureColNames = append(pureColNames, colName)
 	}
@@ -123,7 +123,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() string {
 	insertValuesSQL := csep + ",CURRENT_TIMESTAMP"
 
 	updateStatementsforToastCols := m.generateUpdateStatements(pureColNames,
-		m.UnchangedToastColumns, m.peerdbCols)
+		m.unchangedToastColumns, m.peerdbCols)
 	if m.peerdbCols.SoftDelete {
 		softDeleteInsertColumnsSQL := insertColumnsSQL + fmt.Sprintf(", `%s`", m.peerdbCols.SoftDeleteColName)
 		softDeleteInsertValuesSQL := insertValuesSQL + ", TRUE"
@@ -134,8 +134,8 @@ func (m *mergeStmtGenerator) generateMergeStmt() string {
 	}
 	updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")
 
-	pkeySelectSQLArray := make([]string, 0, len(m.NormalizedTableSchema.PrimaryKeyColumns))
-	for _, pkeyColName := range m.NormalizedTableSchema.PrimaryKeyColumns {
+	pkeySelectSQLArray := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns))
+	for _, pkeyColName := range m.normalizedTableSchema.PrimaryKeyColumns {
 		pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("_peerdb_target.%s = _peerdb_deduped.%s",
 			pkeyColName, pkeyColName))
 	}
@@ -153,14 +153,14 @@ func (m *mergeStmtGenerator) generateMergeStmt() string {
 	}
 
 	return fmt.Sprintf(`
-	MERGE %s.%s _peerdb_target USING (%s,%s) _peerdb_deduped
+	MERGE %s _peerdb_target USING (%s,%s) _peerdb_deduped
 	ON %s
 	WHEN NOT MATCHED and (_peerdb_deduped._peerdb_record_type != 2) THEN
 		INSERT (%s) VALUES (%s)
 	%s
 	WHEN MATCHED AND (_peerdb_deduped._peerdb_record_type = 2) THEN
 	%s;
-	`, m.Dataset, m.NormalizedTable, m.generateFlattenedCTE(), m.generateDeDupedCTE(),
+	`, m.dstDatasetTable.string(), m.generateFlattenedCTE(), m.generateDeDupedCTE(),
 		pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)
 }
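One behavioral detail worth calling out in the generator changes: rows in the shared raw table are still selected by the raw mapping string (`dstTableName`, exactly as written in the table mapping), while the MERGE target is the parsed, dataset-qualified name (`dstDatasetTable`). A hedged sketch of the filter the flattened CTE builds, with placeholder values:

```go
package main

import "fmt"

func main() {
	// placeholders: the raw table lives in the connector's default dataset,
	// while the destination may live in a different dataset
	rawTable := "peerdb_dataset._peerdb_raw_myflow"
	dstTableName := "analytics.users" // as written in the table mapping
	normalizeBatchID, syncBatchID := int64(7), int64(9)

	filter := fmt.Sprintf(
		"SELECT ... FROM %s WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d and "+
			"_peerdb_destination_table_name='%s'",
		rawTable, normalizeBatchID, syncBatchID, dstTableName)
	fmt.Println(filter)
}
```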
diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go
index df771e50a2..305bab01eb 100644
--- a/flow/connectors/bigquery/qrep.go
+++ b/flow/connectors/bigquery/qrep.go
@@ -45,7 +45,7 @@ func (c *BigQueryConnector) SyncQRepRecords(
 		" partition %s of destination table %s",
 		partition.PartitionId, destTable))
 
-	avroSync := &QRepAvroSyncMethod{connector: c, gcsBucket: config.StagingPath}
+	avroSync := NewQRepAvroSyncMethod(c, config.StagingPath, config.FlowJobName)
 	return avroSync.SyncQRepRecords(config.FlowJobName, destTable, partition,
 		tblMetadata, stream, config.SyncedAtColName, config.SoftDeleteColName)
 }
@@ -53,11 +53,11 @@
 func (c *BigQueryConnector) replayTableSchemaDeltasQRep(config *protos.QRepConfig, partition *protos.QRepPartition,
 	srcSchema *model.QRecordSchema,
 ) (*bigquery.TableMetadata, error) {
-	destTable := config.DestinationTableIdentifier
-	bqTable := c.client.Dataset(c.datasetID).Table(destTable)
+	destDatasetTable, _ := c.convertToDatasetTable(config.DestinationTableIdentifier)
+	bqTable := c.client.Dataset(destDatasetTable.dataset).Table(destDatasetTable.table)
 	dstTableMetadata, err := bqTable.Metadata(c.ctx)
 	if err != nil {
-		return nil, fmt.Errorf("failed to get metadata of table %s: %w", destTable, err)
+		return nil, fmt.Errorf("failed to get metadata of table %s: %w", destDatasetTable.string(), err)
 	}
 
 	tableSchemaDelta := &protos.TableSchemaDelta{
@@ -92,7 +92,7 @@ func (c *BigQueryConnector) replayTableSchemaDeltasQRep(config *protos.QRepConfi
 	}
 	dstTableMetadata, err = bqTable.Metadata(c.ctx)
 	if err != nil {
-		return nil, fmt.Errorf("failed to get metadata of table %s: %w", destTable, err)
+		return nil, fmt.Errorf("failed to get metadata of table %s: %w", destDatasetTable.string(), err)
 	}
 	return dstTableMetadata, nil
 }
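On the two error messages above: the parsed identifier is formatted via its `string()` method because the lowercase method does not satisfy `fmt.Stringer`, so handing the struct pointer straight to `%s` would print Go struct syntax rather than `dataset.table`. A small demonstration:

```go
package main

import "fmt"

type datasetTable struct{ dataset, table string }

// lowercase string() does not implement fmt.Stringer
func (d *datasetTable) string() string { return d.dataset + "." + d.table }

func main() {
	dt := &datasetTable{"analytics", "users"}
	fmt.Printf("%s\n", dt)          // &{analytics users}: not useful in an error message
	fmt.Printf("%s\n", dt.string()) // analytics.users
}
```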
diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go
index 7ed87b0c06..8e600d5279 100644
--- a/flow/connectors/bigquery/qrep_avro_sync.go
+++ b/flow/connectors/bigquery/qrep_avro_sync.go
@@ -35,7 +35,7 @@ func NewQRepAvroSyncMethod(connector *BigQueryConnector, gcsBucket string,
 }
 
 func (s *QRepAvroSyncMethod) SyncRecords(
-	dstTableName string,
+	rawTableName string,
 	flowJobName string,
 	lastCP int64,
 	dstTableMetadata *bigquery.TableMetadata,
@@ -45,16 +45,20 @@ func (s *QRepAvroSyncMethod) SyncRecords(
 	activity.RecordHeartbeat(s.connector.ctx, time.Minute,
 		fmt.Sprintf("Flow job %s: Obtaining Avro schema"+
 			" for destination table %s and sync batch ID %d",
-			flowJobName, dstTableName, syncBatchID),
+			flowJobName, rawTableName, syncBatchID),
 	)
 	// You will need to define your Avro schema as a string
-	avroSchema, err := DefineAvroSchema(dstTableName, dstTableMetadata, "", "")
+	avroSchema, err := DefineAvroSchema(rawTableName, dstTableMetadata, "", "")
 	if err != nil {
 		return 0, fmt.Errorf("failed to define Avro schema: %w", err)
 	}
 
-	stagingTable := fmt.Sprintf("%s_%s_staging", dstTableName, fmt.Sprint(syncBatchID))
-	numRecords, err := s.writeToStage(fmt.Sprint(syncBatchID), dstTableName, avroSchema, stagingTable, stream)
+	stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, fmt.Sprint(syncBatchID))
+	numRecords, err := s.writeToStage(fmt.Sprint(syncBatchID), rawTableName, avroSchema,
+		&datasetTable{
+			dataset: s.connector.datasetID,
+			table:   stagingTable,
+		}, stream)
 	if err != nil {
 		return -1, fmt.Errorf("failed to push to avro stage: %v", err)
 	}
@@ -62,7 +66,7 @@ func (s *QRepAvroSyncMethod) SyncRecords(
 	bqClient := s.connector.client
 	datasetID := s.connector.datasetID
 	insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT * FROM `%s.%s`;",
-		datasetID, dstTableName, datasetID, stagingTable)
+		datasetID, rawTableName, datasetID, stagingTable)
 	updateMetadataStmt, err := s.connector.getUpdateMetadataStmt(flowJobName, lastCP, syncBatchID)
 	if err != nil {
 		return -1, fmt.Errorf("failed to update metadata: %v", err)
@@ -71,7 +75,7 @@ func (s *QRepAvroSyncMethod) SyncRecords(
 	activity.RecordHeartbeat(s.connector.ctx, time.Minute,
 		fmt.Sprintf("Flow job %s: performing insert and update transaction"+
 			" for destination table %s and sync batch ID %d",
-			flowJobName, dstTableName, syncBatchID),
+			flowJobName, rawTableName, syncBatchID),
 	)
 
 	stmts := []string{
@@ -91,12 +95,12 @@ func (s *QRepAvroSyncMethod) SyncRecords(
 		slog.Error("failed to delete staging table "+stagingTable,
 			slog.Any("error", err),
 			slog.String("syncBatchID", fmt.Sprint(syncBatchID)),
-			slog.String("destinationTable", dstTableName))
+			slog.String("rawTableName", rawTableName))
 	}
 
-	slog.Info(fmt.Sprintf("loaded stage into %s.%s", datasetID, dstTableName),
+	slog.Info(fmt.Sprintf("loaded stage into %s.%s", datasetID, rawTableName),
 		slog.String(string(shared.FlowNameKey), flowJobName),
-		slog.String("dstTableName", dstTableName))
+		slog.String("rawTableName", rawTableName))
 
 	return numRecords, nil
 }
@@ -124,8 +128,14 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
 	slog.Info("Obtained Avro schema for destination table", flowLog)
 	slog.Info(fmt.Sprintf("Avro schema: %v\n", avroSchema), flowLog)
 	// create a staging table name with partitionID replace hyphens with underscores
-	stagingTable := fmt.Sprintf("%s_%s_staging", dstTableName, strings.ReplaceAll(partition.PartitionId, "-", "_"))
-	numRecords, err := s.writeToStage(partition.PartitionId, flowJobName, avroSchema, stagingTable, stream)
+	dstDatasetTable, _ := s.connector.convertToDatasetTable(dstTableName)
+	stagingDatasetTable := &datasetTable{
+		dataset: dstDatasetTable.dataset,
+		table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table,
+			strings.ReplaceAll(partition.PartitionId, "-", "_")),
+	}
+	numRecords, err := s.writeToStage(partition.PartitionId, flowJobName, avroSchema,
+		stagingDatasetTable, stream)
 	if err != nil {
 		return -1, fmt.Errorf("failed to push to avro stage: %v", err)
 	}
@@ -135,7 +145,6 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
 		flowJobName, dstTableName, partition.PartitionId),
 	)
 	bqClient := s.connector.client
-	datasetID := s.connector.datasetID
 	selector := "*"
 	if softDeleteCol != "" { // PeerDB column
 		selector += ", FALSE"
 	}
 	if syncedAtCol != "" { // PeerDB column
 		selector += ", CURRENT_TIMESTAMP"
 	}
 	// Insert the records from the staging table into the destination table
-	insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT %s FROM `%s.%s`;",
-		datasetID, dstTableName, selector, datasetID, stagingTable)
+	insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT %s FROM `%s`;",
+		dstDatasetTable.string(), selector, stagingDatasetTable.string())
 
 	insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime)
 	if err != nil {
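A sketch of the selector logic in the hunk above: staged rows are copied into the destination, with `FALSE` appended for the soft-delete column and `CURRENT_TIMESTAMP` for the synced-at column when those PeerDB columns are configured. Table and column names here are placeholders:

```go
package main

import "fmt"

// buildInsert mirrors the INSERT INTO ... SELECT construction in
// SyncQRepRecords; both table identifiers are already dataset-qualified.
func buildInsert(dst, staging, softDeleteCol, syncedAtCol string) string {
	selector := "*"
	if softDeleteCol != "" {
		selector += ", FALSE"
	}
	if syncedAtCol != "" {
		selector += ", CURRENT_TIMESTAMP"
	}
	return fmt.Sprintf("INSERT INTO `%s` SELECT %s FROM `%s`;", dst, selector, staging)
}

func main() {
	fmt.Println(buildInsert("analytics.users", "analytics.users_p1_staging",
		"_peerdb_is_deleted", "_peerdb_synced_at"))
}
```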
- slog.Error("failed to delete staging table "+stagingTable, + slog.Error("failed to delete staging table "+stagingDatasetTable.string(), slog.Any("error", err), flowLog) } - slog.Info(fmt.Sprintf("loaded stage into %s.%s", datasetID, dstTableName), flowLog) + slog.Info(fmt.Sprintf("loaded stage into %s", dstDatasetTable.string()), flowLog) return numRecords, nil } @@ -323,7 +333,7 @@ func (s *QRepAvroSyncMethod) writeToStage( syncID string, objectFolder string, avroSchema *model.QRecordAvroSchemaDefinition, - stagingTable string, + stagingTable *datasetTable, stream *model.QRecordStream, ) (int, error) { shutdown := utils.HeartbeatRoutine(s.connector.ctx, time.Minute, @@ -379,7 +389,6 @@ func (s *QRepAvroSyncMethod) writeToStage( slog.Info(fmt.Sprintf("wrote %d records", avroFile.NumRecords), idLog) bqClient := s.connector.client - datasetID := s.connector.datasetID var avroRef bigquery.LoadSource if s.gcsBucket != "" { gcsRef := bigquery.NewGCSReference(fmt.Sprintf("gs://%s/%s", s.gcsBucket, avroFile.FilePath)) @@ -396,7 +405,7 @@ func (s *QRepAvroSyncMethod) writeToStage( avroRef = localRef } - loader := bqClient.Dataset(datasetID).Table(stagingTable).LoaderFrom(avroRef) + loader := bqClient.Dataset(stagingTable.dataset).Table(stagingTable.table).LoaderFrom(avroRef) loader.UseAvroLogicalTypes = true loader.WriteDisposition = bigquery.WriteTruncate job, err := loader.Run(s.connector.ctx) @@ -412,7 +421,7 @@ func (s *QRepAvroSyncMethod) writeToStage( if err := status.Err(); err != nil { return 0, fmt.Errorf("failed to load Avro file into BigQuery table: %w", err) } - slog.Info(fmt.Sprintf("Pushed into %s/%s", avroFile.FilePath, syncID)) + slog.Info(fmt.Sprintf("Pushed into %s", avroFile.FilePath)) err = s.connector.waitForTableReady(stagingTable) if err != nil { diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 05347a4263..c8ba3dad41 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -164,7 +164,7 @@ func (c *EventHubConnector) processBatch( return 0, err } - topicName, err := NewScopedEventhub(record.GetTableName()) + topicName, err := NewScopedEventhub(record.GetDestinationTableName()) if err != nil { c.logger.Error("failed to get topic name", slog.Any("error", err)) return 0, err diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 2be3fcb2a5..4c5693f292 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -365,7 +365,7 @@ func (p *PostgresCDCSource) consumeStream( } if rec != nil { - tableName := rec.GetTableName() + tableName := rec.GetDestinationTableName() switch r := rec.(type) { case *model.UpdateRecord: // tableName here is destination tableName. @@ -843,7 +843,7 @@ func (p *PostgresCDCSource) processRelationMessage( func (p *PostgresCDCSource) recToTablePKey(req *model.PullRecordsRequest, rec model.Record, ) (*model.TableWithPkey, error) { - tableName := rec.GetTableName() + tableName := rec.GetDestinationTableName() pkeyColsMerged := make([]byte, 0) for _, pkeyCol := range req.TableNameSchemaMapping[tableName].PrimaryKeyColumns { diff --git a/flow/e2e/bigquery/bigquery_helper.go b/flow/e2e/bigquery/bigquery_helper.go index fb9dadb9ba..21bd3b5c75 100644 --- a/flow/e2e/bigquery/bigquery_helper.go +++ b/flow/e2e/bigquery/bigquery_helper.go @@ -94,12 +94,11 @@ func generateBQPeer(bigQueryConfig *protos.BigqueryConfig) *protos.Peer { } // datasetExists checks if the dataset exists. 
diff --git a/flow/e2e/bigquery/bigquery_helper.go b/flow/e2e/bigquery/bigquery_helper.go
index fb9dadb9ba..21bd3b5c75 100644
--- a/flow/e2e/bigquery/bigquery_helper.go
+++ b/flow/e2e/bigquery/bigquery_helper.go
@@ -94,12 +94,11 @@ func generateBQPeer(bigQueryConfig *protos.BigqueryConfig) *protos.Peer {
 }
 
 // datasetExists checks if the dataset exists.
-func (b *BigQueryTestHelper) datasetExists() (bool, error) {
-	dataset := b.client.Dataset(b.Config.DatasetId)
+func (b *BigQueryTestHelper) datasetExists(datasetName string) (bool, error) {
+	dataset := b.client.Dataset(datasetName)
 	meta, err := dataset.Metadata(context.Background())
 	if err != nil {
 		// if err message contains `notFound` then dataset does not exist.
-		// first we cast the error to a bigquery.Error
 		if strings.Contains(err.Error(), "notFound") {
-			fmt.Printf("dataset %s does not exist\n", b.Config.DatasetId)
+			fmt.Printf("dataset %s does not exist\n", datasetName)
 			return false, nil
@@ -117,12 +116,12 @@
 
 // RecreateDataset recreates the dataset, i.e, deletes it if exists and creates it again.
 func (b *BigQueryTestHelper) RecreateDataset() error {
-	exists, err := b.datasetExists()
+	exists, err := b.datasetExists(b.datasetName)
 	if err != nil {
-		return fmt.Errorf("failed to check if dataset %s exists: %w", b.Config.DatasetId, err)
+		return fmt.Errorf("failed to check if dataset %s exists: %w", b.datasetName, err)
 	}
 
-	dataset := b.client.Dataset(b.Config.DatasetId)
+	dataset := b.client.Dataset(b.datasetName)
 	if exists {
 		err := dataset.DeleteWithContents(context.Background())
 		if err != nil {
@@ -135,13 +134,13 @@
 		return fmt.Errorf("failed to create dataset: %w", err)
 	}
 
-	fmt.Printf("created dataset %s successfully\n", b.Config.DatasetId)
+	fmt.Printf("created dataset %s successfully\n", b.datasetName)
 	return nil
 }
 
 // DropDataset drops the dataset.
-func (b *BigQueryTestHelper) DropDataset() error {
-	exists, err := b.datasetExists()
+func (b *BigQueryTestHelper) DropDataset(datasetName string) error {
+	exists, err := b.datasetExists(datasetName)
 	if err != nil {
-		return fmt.Errorf("failed to check if dataset %s exists: %w", b.Config.DatasetId, err)
+		return fmt.Errorf("failed to check if dataset %s exists: %w", datasetName, err)
 	}
@@ -150,7 +149,7 @@
 		return nil
 	}
 
-	dataset := b.client.Dataset(b.Config.DatasetId)
+	dataset := b.client.Dataset(datasetName)
 	err = dataset.DeleteWithContents(context.Background())
 	if err != nil {
 		return fmt.Errorf("failed to delete dataset: %w", err)
@@ -171,7 +170,11 @@ func (b *BigQueryTestHelper) RunCommand(command string) error {
 
 // countRows(tableName) returns the number of rows in the given table.
 func (b *BigQueryTestHelper) countRows(tableName string) (int, error) {
-	command := fmt.Sprintf("SELECT COUNT(*) FROM `%s.%s`", b.Config.DatasetId, tableName)
+	return b.countRowsWithDataset(b.datasetName, tableName)
+}
+
+func (b *BigQueryTestHelper) countRowsWithDataset(dataset, tableName string) (int, error) {
+	command := fmt.Sprintf("SELECT COUNT(*) FROM `%s.%s`", dataset, tableName)
 	it, err := b.client.Query(command).Read(context.Background())
 	if err != nil {
 		return 0, fmt.Errorf("failed to run command: %w", err)
diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go
index b28577f4d3..c76688f79b 100644
--- a/flow/e2e/bigquery/peer_flow_bq_test.go
+++ b/flow/e2e/bigquery/peer_flow_bq_test.go
@@ -150,7 +150,7 @@ func (s PeerFlowE2ETestSuiteBQ) tearDownSuite() {
 		s.FailNow()
 	}
 
-	err = s.bqHelper.DropDataset()
+	err = s.bqHelper.DropDataset(s.bqHelper.datasetName)
 	if err != nil {
 		slog.Error("failed to tear down bigquery", slog.Any("error", err))
 		s.FailNow()
@@ -1203,3 +1203,72 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Columns_BQ() {
 
 	env.AssertExpectations(s.t)
 }
+
+func (s PeerFlowE2ETestSuiteBQ) Test_Multi_Table_Multi_Dataset_BQ() {
+	env := e2e.NewTemporalTestWorkflowEnvironment()
+	e2e.RegisterWorkflowsAndActivities(env, s.t)
+
+	srcTable1Name := s.attachSchemaSuffix("test1_bq")
+	dstTable1Name := "test1_bq"
+	secondDataset := fmt.Sprintf("%s_2", s.bqHelper.datasetName)
+	srcTable2Name := s.attachSchemaSuffix("test2_bq")
+	dstTable2Name := "test2_bq"
+
+	_, err := s.pool.Exec(context.Background(), fmt.Sprintf(`
+	CREATE TABLE %s(id serial primary key, c1 int, c2 text);
+	CREATE TABLE %s(id serial primary key, c1 int, c2 text);
+	`, srcTable1Name, srcTable2Name))
+	require.NoError(s.t, err)
+
+	connectionGen := e2e.FlowConnectionGenerationConfig{
+		FlowJobName: s.attachSuffix("test_multi_table_multi_dataset_bq"),
+		TableNameMapping: map[string]string{
+			srcTable1Name: dstTable1Name,
+			srcTable2Name: fmt.Sprintf("%s.%s", secondDataset, dstTable2Name),
+		},
+		PostgresPort:   e2e.PostgresPort,
+		Destination:    s.bqHelper.Peer,
+		CdcStagingPath: "",
+	}
+
+	flowConnConfig, err := connectionGen.GenerateFlowConnectionConfigs()
+	require.NoError(s.t, err)
+
+	limits := peerflow.CDCFlowLimits{
+		ExitAfterRecords: 2,
+		MaxBatchSize:     100,
+	}
+
+	// in a separate goroutine, wait for PeerFlowStatusQuery to finish setup
+	// and insert rows into both source tables
+	go func() {
+		e2e.SetupCDCFlowStatusQuery(env, connectionGen)
+		/* inserting across multiple tables */
+		_, err = s.pool.Exec(context.Background(), fmt.Sprintf(`
+		INSERT INTO %s (c1,c2) VALUES (1,'dummy_1');
+		INSERT INTO %s (c1,c2) VALUES (-1,'dummy_-1');
+		`, srcTable1Name, srcTable2Name))
+		require.NoError(s.t, err)
+		fmt.Println("Executed an insert on two tables")
+	}()
+
+	env.ExecuteWorkflow(peerflow.CDCFlowWorkflowWithConfig, flowConnConfig, &limits, nil)
+
+	// Verify workflow completes without error
+	require.True(s.t, env.IsWorkflowCompleted())
+	err = env.GetWorkflowError()
+	require.NoError(s.t, err)
+
+	count1, err := s.bqHelper.countRows(dstTable1Name)
+	require.NoError(s.t, err)
+	count2, err := s.bqHelper.countRowsWithDataset(secondDataset, dstTable2Name)
+	require.NoError(s.t, err)
+
+	s.Equal(1, count1)
+	s.Equal(1, count2)
+
+	err = s.bqHelper.DropDataset(secondDataset)
+	require.NoError(s.t, err)
+
+	env.AssertExpectations(s.t)
+}
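The new test exercises the feature end to end: one destination is a bare table name (it lands in the peer's default dataset), the other is dataset-qualified. Sketched in isolation with placeholder names:

```go
package main

import "fmt"

func main() {
	// placeholder source/destination names; a bare destination resolves to
	// the peer's default dataset, a qualified one selects its own dataset
	tableNameMapping := map[string]string{
		"public.test1_bq": "test1_bq",                // -> <default_dataset>.test1_bq
		"public.test2_bq": "second_dataset.test2_bq", // -> second_dataset.test2_bq
	}
	for src, dst := range tableNameMapping {
		fmt.Printf("%s -> %s\n", src, dst)
	}
}
```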
diff --git a/flow/model/model.go b/flow/model/model.go
index 581b57178b..fc2c12d849 100644
--- a/flow/model/model.go
+++ b/flow/model/model.go
@@ -58,7 +58,7 @@ type Record interface {
 	// GetCheckPointID returns the ID of the record.
 	GetCheckPointID() int64
-	// get table name
-	GetTableName() string
+	// get destination table name
+	GetDestinationTableName() string
 	// get columns and values for the record
 	GetItems() *RecordItems
 }
@@ -244,7 +244,7 @@ func (r *InsertRecord) GetCheckPointID() int64 {
 	return r.CheckPointID
 }
 
-func (r *InsertRecord) GetTableName() string {
+func (r *InsertRecord) GetDestinationTableName() string {
 	return r.DestinationTableName
 }
 
@@ -273,7 +273,7 @@ func (r *UpdateRecord) GetCheckPointID() int64 {
 }
 
 // Implement Record interface for UpdateRecord.
-func (r *UpdateRecord) GetTableName() string {
+func (r *UpdateRecord) GetDestinationTableName() string {
 	return r.DestinationTableName
 }
 
@@ -299,7 +299,7 @@ func (r *DeleteRecord) GetCheckPointID() int64 {
 	return r.CheckPointID
 }
 
-func (r *DeleteRecord) GetTableName() string {
+func (r *DeleteRecord) GetDestinationTableName() string {
 	return r.DestinationTableName
 }
 
@@ -470,8 +470,8 @@ func (r *RelationRecord) GetCheckPointID() int64 {
 	return r.CheckPointID
 }
 
-func (r *RelationRecord) GetTableName() string {
-	return r.TableSchemaDelta.SrcTableName
+func (r *RelationRecord) GetDestinationTableName() string {
+	return r.TableSchemaDelta.DstTableName
}
 
 func (r *RelationRecord) GetItems() *RecordItems {
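A minimal restatement of the renamed interface, for readers following the call sites changed in eventhub.go and cdc.go above: every record now answers with the destination table name, which is what connectors key on for routing (raw-table rows, Event Hub topics, pkey lookups). Note that RelationRecord previously reported its source table; it now reports the destination side, consistent with the other record types. The structs below are trimmed-down stand-ins for illustration:

```go
package main

import "fmt"

type Record interface {
	GetDestinationTableName() string
}

type InsertRecord struct{ DestinationTableName string }

func (r *InsertRecord) GetDestinationTableName() string { return r.DestinationTableName }

type RelationRecord struct{ DstTableName string }

// after this change, schema-change records route by destination table too
func (r *RelationRecord) GetDestinationTableName() string { return r.DstTableName }

func main() {
	recs := []Record{
		&InsertRecord{DestinationTableName: "analytics.users"},
		&RelationRecord{DstTableName: "analytics.users"},
	}
	for _, r := range recs {
		fmt.Println(r.GetDestinationTableName())
	}
}
```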