From 5e18a79123883e3f641c52de21e4ce1c74aff65d Mon Sep 17 00:00:00 2001 From: Amogh-Bharadwaj Date: Thu, 25 Jan 2024 14:52:03 +0530 Subject: [PATCH] fix cross project support --- flow/connectors/bigquery/bigquery.go | 56 +++++++++++++--------- flow/connectors/bigquery/qrep.go | 8 ++-- flow/connectors/bigquery/qrep_avro_sync.go | 10 ++-- 3 files changed, 43 insertions(+), 31 deletions(-) diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 59498d927d..3775d91631 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -260,11 +260,11 @@ func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string, for _, addedColumn := range schemaDelta.AddedColumns { dstDatasetTable, _ := c.convertToDatasetTable(schemaDelta.DstTableName) query := c.client.Query(fmt.Sprintf( - "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s", dstDatasetTable.dataset, + "ALTER TABLE %s ADD COLUMN IF NOT EXISTS `%s` %s", dstDatasetTable.table, addedColumn.ColumnName, qValueKindToBigQueryType(addedColumn.ColumnType))) query.DefaultProjectID = c.projectID - query.DefaultDatasetID = c.datasetID + query.DefaultDatasetID = dstDatasetTable.dataset _, err := query.Read(c.ctx) if err != nil { return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName, @@ -312,7 +312,7 @@ func (c *BigQueryConnector) SetupMetadataTables() error { } func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) { - query := fmt.Sprintf("SELECT offset FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName) + query := fmt.Sprintf("SELECT offset FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName) q := c.client.Query(query) q.DefaultProjectID = c.projectID q.DefaultDatasetID = c.datasetID @@ -339,8 +339,7 @@ func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) { func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) error { query := fmt.Sprintf( - "UPDATE %s.%s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'", - c.datasetID, + "UPDATE %s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'", MirrorJobsTable, lastOffset, jobName, @@ -357,8 +356,8 @@ func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) erro } func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { - query := fmt.Sprintf("SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name = '%s'", - c.datasetID, MirrorJobsTable, jobName) + query := fmt.Sprintf("SELECT sync_batch_id FROM %s WHERE mirror_job_name = '%s'", + MirrorJobsTable, jobName) q := c.client.Query(query) q.DefaultProjectID = c.projectID q.DefaultDatasetID = c.datasetID @@ -384,8 +383,8 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { } func (c *BigQueryConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) { - query := fmt.Sprintf("SELECT sync_batch_id, normalize_batch_id FROM %s.%s WHERE mirror_job_name = '%s'", - c.datasetID, MirrorJobsTable, jobName) + query := fmt.Sprintf("SELECT sync_batch_id, normalize_batch_id FROM %s WHERE mirror_job_name = '%s'", + MirrorJobsTable, jobName) q := c.client.Query(query) q.DefaultProjectID = c.projectID q.DefaultDatasetID = c.datasetID @@ -422,9 +421,9 @@ func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syn rawTableName := c.getRawTableName(flowJobName) // Prepare the query to retrieve distinct tables in that batch - query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s.%s + query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d`, - c.datasetID, rawTableName, normalizeBatchID, syncBatchID) + rawTableName, normalizeBatchID, syncBatchID) // Run the query q := c.client.Query(query) q.DefaultProjectID = c.projectID @@ -465,10 +464,10 @@ func (c *BigQueryConnector) getTableNametoUnchangedCols(flowJobName string, sync // where a placeholder value for unchanged cols can be set in DeleteRecord if there is no backfill // we don't want these particular DeleteRecords to be used in the update statement query := fmt.Sprintf(`SELECT _peerdb_destination_table_name, - array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s.%s + array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s WHERE _peerdb_batch_id > %d AND _peerdb_batch_id <= %d AND _peerdb_record_type != 2 GROUP BY _peerdb_destination_table_name`, - c.datasetID, rawTableName, normalizeBatchID, syncBatchID) + rawTableName, normalizeBatchID, syncBatchID) // Run the query q := c.client.Query(query) q.DefaultDatasetID = c.datasetID @@ -593,6 +592,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) dstDatasetTable, _ := c.convertToDatasetTable(tableName) mergeGen := &mergeStmtGenerator{ rawDatasetTable: &datasetTable{ + project: c.projectID, dataset: c.datasetID, table: rawTableName, }, @@ -624,8 +624,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } // update metadata to make the last normalized batch id to the recent last sync batch id. updateMetadataStmt := fmt.Sprintf( - "UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name='%s';", - c.datasetID, MirrorJobsTable, batchIDs.SyncBatchID, req.FlowJobName) + "UPDATE %s SET normalize_batch_id=%d WHERE mirror_job_name='%s';", + MirrorJobsTable, batchIDs.SyncBatchID, req.FlowJobName) query := c.client.Query(updateMetadataStmt) query.DefaultProjectID = c.projectID @@ -725,12 +725,12 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec // create the job in the metadata table jobStatement := fmt.Sprintf( - "INSERT INTO %s.%s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);", - c.datasetID, MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID) + "INSERT INTO %s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);", + MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID) if hasJob { jobStatement = fmt.Sprintf( - "UPDATE %s.%s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';", - c.datasetID, MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName) + "UPDATE %s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';", + MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName) } return jobStatement, nil @@ -739,8 +739,8 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec // metadataHasJob checks if the metadata table has the given job. func (c *BigQueryConnector) metadataHasJob(jobName string) (bool, error) { checkStmt := fmt.Sprintf( - "SELECT COUNT(*) FROM %s.%s WHERE mirror_job_name = '%s'", - c.datasetID, MirrorJobsTable, jobName) + "SELECT COUNT(*) FROM %s WHERE mirror_job_name = '%s'", + MirrorJobsTable, jobName) q := c.client.Query(checkStmt) q.DefaultProjectID = c.projectID @@ -878,7 +878,7 @@ func (c *BigQueryConnector) SyncFlowCleanup(jobName string) error { } // deleting job from metadata table - query := fmt.Sprintf("DELETE FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName) + query := fmt.Sprintf("DELETE FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName) queryHandler := c.client.Query(query) queryHandler.DefaultProjectID = c.projectID queryHandler.DefaultDatasetID = c.datasetID @@ -1017,12 +1017,16 @@ func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFro } type datasetTable struct { + project string dataset string table string } func (d *datasetTable) string() string { - return fmt.Sprintf("%s.%s", d.dataset, d.table) + if d.project == "" { + return fmt.Sprintf("%s.%s", d.dataset, d.table) + } + return fmt.Sprintf("%s.%s.%s", d.project, d.dataset, d.table) } func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTable, error) { @@ -1037,6 +1041,12 @@ func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTab dataset: parts[0], table: parts[1], }, nil + } else if len(parts) == 3 { + return &datasetTable{ + project: parts[0], + dataset: parts[1], + table: parts[2], + }, nil } else { return nil, fmt.Errorf("invalid BigQuery table name: %s", tableName) } diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go index 5878dd2bd0..c0aafe045f 100644 --- a/flow/connectors/bigquery/qrep.go +++ b/flow/connectors/bigquery/qrep.go @@ -112,10 +112,10 @@ func (c *BigQueryConnector) createMetadataInsertStatement( partitionJSON := string(pbytes) insertMetadataStmt := fmt.Sprintf( - "INSERT INTO %s._peerdb_query_replication_metadata"+ + "INSERT INTO _peerdb_query_replication_metadata"+ "(flowJobName, partitionID, syncPartition, syncStartTime, syncFinishTime) "+ "VALUES ('%s', '%s', JSON '%s', TIMESTAMP('%s'), CURRENT_TIMESTAMP());", - c.datasetID, jobName, partition.PartitionId, + jobName, partition.PartitionId, partitionJSON, startTime.Format(time.RFC3339)) return insertMetadataStmt, nil @@ -170,8 +170,8 @@ func (c *BigQueryConnector) SetupQRepMetadataTables(config *protos.QRepConfig) e func (c *BigQueryConnector) isPartitionSynced(partitionID string) (bool, error) { queryString := fmt.Sprintf( - "SELECT COUNT(*) FROM %s._peerdb_query_replication_metadata WHERE partitionID = '%s';", - c.datasetID, partitionID, + "SELECT COUNT(*) FROM _peerdb_query_replication_metadata WHERE partitionID = '%s';", + partitionID, ) query := c.client.Query(queryString) diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index 7e6768963a..14fe8ef5f3 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -58,6 +58,7 @@ func (s *QRepAvroSyncMethod) SyncRecords( stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, strconv.FormatInt(syncBatchID, 10)) numRecords, err := s.writeToStage(strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema, &datasetTable{ + project: s.connector.projectID, dataset: s.connector.datasetID, table: stagingTable, }, stream, req.FlowJobName) @@ -67,8 +68,8 @@ func (s *QRepAvroSyncMethod) SyncRecords( bqClient := s.connector.client datasetID := s.connector.datasetID - insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT * FROM `%s.%s`;", - datasetID, rawTableName, datasetID, stagingTable) + insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT * FROM `%s`;", + rawTableName, stagingTable) lastCP, err := req.Records.GetLastCheckpoint() if err != nil { @@ -171,6 +172,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( // create a staging table name with partitionID replace hyphens with underscores dstDatasetTable, _ := s.connector.convertToDatasetTable(dstTableName) stagingDatasetTable := &datasetTable{ + project: s.connector.projectID, dataset: dstDatasetTable.dataset, table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table, strings.ReplaceAll(partition.PartitionId, "-", "_")), @@ -198,7 +200,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( } // Insert the records from the staging table into the destination table insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT %s FROM `%s`;", - dstDatasetTable.string(), selector, stagingDatasetTable.string()) + dstTableName, selector, stagingDatasetTable) insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime) if err != nil { @@ -229,7 +231,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( flowLog) } - slog.Info(fmt.Sprintf("loaded stage into %s", dstDatasetTable.string()), flowLog) + slog.Info(fmt.Sprintf("loaded stage into %s", dstTableName), flowLog) return numRecords, nil }