BigQuery cross project support: unqualify tables for cdc/qrep (#1147)
`DefaultDatasetID` and `DefaultProjectID` work only for unqualified
tables, so this PR unqualifies the table names in our BigQuery queries
for QRep and CDC.

We will update our docs to state explicitly that cross-project mode is
supported only for CDC and QRep (BigQuery). It has yet to be integrated
for resync mirrors.

This PR has been functionally tested with two projects.

Follow up to #1073
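
For context, a minimal standalone sketch of the mechanism this PR relies on (illustrative only — not code from this PR; the project, dataset, and table names are made up):

package main

import (
    "context"
    "fmt"
    "log"

    "cloud.google.com/go/bigquery"
)

func main() {
    ctx := context.Background()

    // The client can be created against one project (the one that runs
    // and is billed for the job)...
    client, err := bigquery.NewClient(ctx, "service-project")
    if err != nil {
        log.Fatal(err)
    }
    defer client.Close()

    // ...while an unqualified table name resolves against these defaults,
    // which may point at a dataset in a different project. A hard-coded
    // "project.dataset.table" reference would bypass the defaults entirely,
    // which is why the queries in this PR drop the qualification.
    q := client.Query("SELECT COUNT(*) FROM events")
    q.DefaultProjectID = "data-project"
    q.DefaultDatasetID = "peer_dataset"

    it, err := q.Read(ctx)
    if err != nil {
        log.Fatal(err)
    }

    var row []bigquery.Value
    if err := it.Next(&row); err == nil {
        fmt.Println("row count:", row[0])
    }
}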
Amogh-Bharadwaj authored Jan 25, 2024
1 parent 8eb2567 commit 44eff42
Showing 4 changed files with 52 additions and 44 deletions.
58 changes: 34 additions & 24 deletions flow/connectors/bigquery/bigquery.go
@@ -260,11 +260,11 @@ func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string,
for _, addedColumn := range schemaDelta.AddedColumns {
dstDatasetTable, _ := c.convertToDatasetTable(schemaDelta.DstTableName)
query := c.client.Query(fmt.Sprintf(
- "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s", dstDatasetTable.dataset,
+ "ALTER TABLE %s ADD COLUMN IF NOT EXISTS `%s` %s",
dstDatasetTable.table, addedColumn.ColumnName,
qValueKindToBigQueryType(addedColumn.ColumnType)))
query.DefaultProjectID = c.projectID
- query.DefaultDatasetID = c.datasetID
+ query.DefaultDatasetID = dstDatasetTable.dataset
_, err := query.Read(c.ctx)
if err != nil {
return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
@@ -312,7 +312,7 @@ func (c *BigQueryConnector) SetupMetadataTables() error {
}

func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) {
- query := fmt.Sprintf("SELECT offset FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName)
+ query := fmt.Sprintf("SELECT offset FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName)
q := c.client.Query(query)
q.DefaultProjectID = c.projectID
q.DefaultDatasetID = c.datasetID
@@ -339,8 +339,7 @@ func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) {

func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) error {
query := fmt.Sprintf(
- "UPDATE %s.%s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'",
- c.datasetID,
+ "UPDATE %s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'",
MirrorJobsTable,
lastOffset,
jobName,
@@ -357,8 +356,8 @@ func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) erro
}

func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) {
- query := fmt.Sprintf("SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name = '%s'",
- c.datasetID, MirrorJobsTable, jobName)
+ query := fmt.Sprintf("SELECT sync_batch_id FROM %s WHERE mirror_job_name = '%s'",
+ MirrorJobsTable, jobName)
q := c.client.Query(query)
q.DefaultProjectID = c.projectID
q.DefaultDatasetID = c.datasetID
@@ -384,8 +383,8 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) {
}

func (c *BigQueryConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) {
- query := fmt.Sprintf("SELECT sync_batch_id, normalize_batch_id FROM %s.%s WHERE mirror_job_name = '%s'",
- c.datasetID, MirrorJobsTable, jobName)
+ query := fmt.Sprintf("SELECT sync_batch_id, normalize_batch_id FROM %s WHERE mirror_job_name = '%s'",
+ MirrorJobsTable, jobName)
q := c.client.Query(query)
q.DefaultProjectID = c.projectID
q.DefaultDatasetID = c.datasetID
@@ -422,9 +421,9 @@ func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syn
rawTableName := c.getRawTableName(flowJobName)

// Prepare the query to retrieve distinct tables in that batch
- query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s.%s
+ query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s
WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d`,
- c.datasetID, rawTableName, normalizeBatchID, syncBatchID)
+ rawTableName, normalizeBatchID, syncBatchID)
// Run the query
q := c.client.Query(query)
q.DefaultProjectID = c.projectID
@@ -465,10 +464,10 @@ func (c *BigQueryConnector) getTableNametoUnchangedCols(flowJobName string, sync
// where a placeholder value for unchanged cols can be set in DeleteRecord if there is no backfill
// we don't want these particular DeleteRecords to be used in the update statement
query := fmt.Sprintf(`SELECT _peerdb_destination_table_name,
- array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s.%s
+ array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s
WHERE _peerdb_batch_id > %d AND _peerdb_batch_id <= %d AND _peerdb_record_type != 2
GROUP BY _peerdb_destination_table_name`,
- c.datasetID, rawTableName, normalizeBatchID, syncBatchID)
+ rawTableName, normalizeBatchID, syncBatchID)
// Run the query
q := c.client.Query(query)
q.DefaultDatasetID = c.datasetID
@@ -593,6 +592,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
dstDatasetTable, _ := c.convertToDatasetTable(tableName)
mergeGen := &mergeStmtGenerator{
rawDatasetTable: &datasetTable{
+ project: c.projectID,
dataset: c.datasetID,
table: rawTableName,
},
@@ -615,7 +615,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
i+1, len(mergeStmts), tableName))
q := c.client.Query(mergeStmt)
q.DefaultProjectID = c.projectID
- q.DefaultDatasetID = c.datasetID
+ q.DefaultDatasetID = dstDatasetTable.dataset
_, err = q.Read(c.ctx)
if err != nil {
return nil, fmt.Errorf("failed to execute merge statement %s: %v", mergeStmt, err)
@@ -624,8 +624,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
}
// update metadata to make the last normalized batch id to the recent last sync batch id.
updateMetadataStmt := fmt.Sprintf(
- "UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name='%s';",
- c.datasetID, MirrorJobsTable, batchIDs.SyncBatchID, req.FlowJobName)
+ "UPDATE %s SET normalize_batch_id=%d WHERE mirror_job_name='%s';",
+ MirrorJobsTable, batchIDs.SyncBatchID, req.FlowJobName)

query := c.client.Query(updateMetadataStmt)
query.DefaultProjectID = c.projectID
@@ -725,12 +725,12 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec

// create the job in the metadata table
jobStatement := fmt.Sprintf(
- "INSERT INTO %s.%s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);",
- c.datasetID, MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID)
+ "INSERT INTO %s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);",
+ MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID)
if hasJob {
jobStatement = fmt.Sprintf(
- "UPDATE %s.%s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';",
- c.datasetID, MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName)
+ "UPDATE %s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';",
+ MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName)
}

return jobStatement, nil
@@ -739,8 +739,8 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec
// metadataHasJob checks if the metadata table has the given job.
func (c *BigQueryConnector) metadataHasJob(jobName string) (bool, error) {
checkStmt := fmt.Sprintf(
- "SELECT COUNT(*) FROM %s.%s WHERE mirror_job_name = '%s'",
- c.datasetID, MirrorJobsTable, jobName)
+ "SELECT COUNT(*) FROM %s WHERE mirror_job_name = '%s'",
+ MirrorJobsTable, jobName)

q := c.client.Query(checkStmt)
q.DefaultProjectID = c.projectID
@@ -878,7 +878,7 @@ func (c *BigQueryConnector) SyncFlowCleanup(jobName string) error {
}

// deleting job from metadata table
- query := fmt.Sprintf("DELETE FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName)
+ query := fmt.Sprintf("DELETE FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName)
queryHandler := c.client.Query(query)
queryHandler.DefaultProjectID = c.projectID
queryHandler.DefaultDatasetID = c.datasetID
@@ -1017,12 +1017,16 @@ func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFro
}

type datasetTable struct {
+ project string
dataset string
table string
}

func (d *datasetTable) string() string {
- return fmt.Sprintf("%s.%s", d.dataset, d.table)
+ if d.project == "" {
+ return fmt.Sprintf("%s.%s", d.dataset, d.table)
+ }
+ return fmt.Sprintf("%s.%s.%s", d.project, d.dataset, d.table)
}

func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTable, error) {
@@ -1037,6 +1041,12 @@ func (c *BigQueryConnector) convertToDatasetTab
dataset: parts[0],
table: parts[1],
}, nil
+ } else if len(parts) == 3 {
+ return &datasetTable{
+ project: parts[0],
+ dataset: parts[1],
+ table: parts[2],
+ }, nil
} else {
return nil, fmt.Errorf("invalid BigQuery table name: %s", tableName)
}
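
To make the new naming concrete, here is a self-contained copy of the rendering logic added above, runnable on its own (the project, dataset, and table names are made up; the real type lives in this file):

package main

import "fmt"

// datasetTable mirrors the connector's struct above, for illustration only.
type datasetTable struct {
    project, dataset, table string
}

func (d datasetTable) string() string {
    if d.project == "" {
        return fmt.Sprintf("%s.%s", d.dataset, d.table)
    }
    return fmt.Sprintf("%s.%s.%s", d.project, d.dataset, d.table)
}

func main() {
    // Two-part names keep the old dataset.table form.
    fmt.Println(datasetTable{dataset: "peer_ds", table: "events"}.string())
    // Three-part names, as now accepted by convertToDatasetTable, render fully qualified.
    fmt.Println(datasetTable{project: "analytics-proj", dataset: "peer_ds", table: "events"}.string())
}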
20 changes: 8 additions & 12 deletions flow/connectors/bigquery/merge_stmt_generator.go
@@ -93,9 +93,9 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string {
)

// normalize anything between last normalized batch id to last sync batchid
- return fmt.Sprintf(`WITH _f AS
- (SELECT %s FROM %s WHERE _peerdb_batch_id>%d AND _peerdb_batch_id<=%d AND
- _peerdb_destination_table_name='%s')`,
+ return fmt.Sprintf("WITH _f AS "+
+ "(SELECT %s FROM `%s` WHERE _peerdb_batch_id>%d AND _peerdb_batch_id<=%d AND "+
+ "_peerdb_destination_table_name='%s')",
strings.Join(flattenedProjs, ","), m.rawDatasetTable.string(), m.normalizeBatchID,
m.syncBatchID, m.dstTableName)
}
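
A note on the backquotes introduced here: rawDatasetTable.string() now renders a fully qualified project.dataset.table name, and BigQuery project IDs may contain hyphens, so the qualified reference must be quoted with backticks. Unqualified names, by contrast, need no quoting and resolve via the query defaults. A small illustrative sketch (made-up identifiers):

package main

import "fmt"

func main() {
    // Qualified: backticks are required because "my-project" contains a hyphen.
    rawQualified := "my-project.peer_ds._peerdb_raw_flow1"
    fmt.Printf("SELECT * FROM `%s`\n", rawQualified)

    // Unqualified: no quoting needed; DefaultProjectID/DefaultDatasetID resolve it.
    fmt.Println("SELECT * FROM _peerdb_raw_flow1")
}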
@@ -169,15 +169,11 @@ func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) s
}
}

- return fmt.Sprintf(`
- MERGE %s _t USING(%s,%s) _d
- ON %s
- WHEN NOT MATCHED AND _d._rt!=2 THEN
- INSERT (%s) VALUES(%s)
- %s
- WHEN MATCHED AND _d._rt=2 THEN
- %s;
- `, m.dstDatasetTable.string(), m.generateFlattenedCTE(), m.generateDeDupedCTE(),
+ return fmt.Sprintf("MERGE `%s` _t USING(%s,%s) _d"+
+ " ON %s WHEN NOT MATCHED AND _d._rt!=2 THEN "+
+ "INSERT (%s) VALUES(%s) "+
+ "%s WHEN MATCHED AND _d._rt=2 THEN %s;",
+ m.dstDatasetTable.table, m.generateFlattenedCTE(), m.generateDeDupedCTE(),
pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)
}

8 changes: 4 additions & 4 deletions flow/connectors/bigquery/qrep.go
@@ -112,10 +112,10 @@ func (c *BigQueryConnector) createMetadataInsertStatement(
partitionJSON := string(pbytes)

insertMetadataStmt := fmt.Sprintf(
- "INSERT INTO %s._peerdb_query_replication_metadata"+
+ "INSERT INTO _peerdb_query_replication_metadata"+
"(flowJobName, partitionID, syncPartition, syncStartTime, syncFinishTime) "+
"VALUES ('%s', '%s', JSON '%s', TIMESTAMP('%s'), CURRENT_TIMESTAMP());",
- c.datasetID, jobName, partition.PartitionId,
+ jobName, partition.PartitionId,
partitionJSON, startTime.Format(time.RFC3339))

return insertMetadataStmt, nil
@@ -170,8 +170,8 @@ func (c *BigQueryConnector) SetupQRepMetadataTables(config *protos.QRepConfig) e

func (c *BigQueryConnector) isPartitionSynced(partitionID string) (bool, error) {
queryString := fmt.Sprintf(
- "SELECT COUNT(*) FROM %s._peerdb_query_replication_metadata WHERE partitionID = '%s';",
- c.datasetID, partitionID,
+ "SELECT COUNT(*) FROM _peerdb_query_replication_metadata WHERE partitionID = '%s';",
+ partitionID,
)

query := c.client.Query(queryString)
10 changes: 6 additions & 4 deletions flow/connectors/bigquery/qrep_avro_sync.go
@@ -58,6 +58,7 @@ func (s *QRepAvroSyncMethod) SyncRecords(
stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, strconv.FormatInt(syncBatchID, 10))
numRecords, err := s.writeToStage(strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema,
&datasetTable{
+ project: s.connector.projectID,
dataset: s.connector.datasetID,
table: stagingTable,
}, stream, req.FlowJobName)
@@ -67,8 +68,8 @@ func (s *QRepAvroSyncMethod) SyncRecords(

bqClient := s.connector.client
datasetID := s.connector.datasetID
- insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT * FROM `%s.%s`;",
- datasetID, rawTableName, datasetID, stagingTable)
+ insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT * FROM `%s`;",
+ rawTableName, stagingTable)

lastCP, err := req.Records.GetLastCheckpoint()
if err != nil {
@@ -171,6 +172,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
// create a staging table name with partitionID replace hyphens with underscores
dstDatasetTable, _ := s.connector.convertToDatasetTable(dstTableName)
stagingDatasetTable := &datasetTable{
+ project: s.connector.projectID,
dataset: dstDatasetTable.dataset,
table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table,
strings.ReplaceAll(partition.PartitionId, "-", "_")),
@@ -198,7 +200,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
}
// Insert the records from the staging table into the destination table
insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT %s FROM `%s`;",
- dstDatasetTable.string(), selector, stagingDatasetTable.string())
+ dstTableName, selector, stagingDatasetTable.string())

insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime)
if err != nil {
@@ -229,7 +231,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
flowLog)
}

- slog.Info(fmt.Sprintf("loaded stage into %s", dstDatasetTable.string()), flowLog)
+ slog.Info(fmt.Sprintf("loaded stage into %s", dstTableName), flowLog)
return numRecords, nil
}
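
For readers new to this file: QRep's Avro sync writes a batch into a staging table first and then publishes it with a single INSERT ... SELECT, which BigQuery runs as one atomic DML statement, so a failed batch never partially appears in the destination. A hedged sketch of the statement shape built above (identifiers made up):

package main

import "fmt"

func main() {
    // The staging table stays fully qualified (it may live outside the
    // client's default project); the destination keeps the user-supplied
    // name, which after this PR may itself be project-qualified.
    staging := "my-project.peer_ds.events_p1_staging"
    dst := "other-project.peer_ds.events"
    selector := "id,created_at,payload"

    fmt.Printf("INSERT INTO `%s` SELECT %s FROM `%s`;\n", dst, selector, staging)
}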

