diff --git a/flow/connectors/bigquery/merge_stmt_generator.go b/flow/connectors/bigquery/merge_stmt_generator.go
index eb8ebb6177..59a269a092 100644
--- a/flow/connectors/bigquery/merge_stmt_generator.go
+++ b/flow/connectors/bigquery/merge_stmt_generator.go
@@ -100,6 +100,44 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string {
 		m.syncBatchID, m.dstTableName)
 }
 
+// This function supports datatypes like JSON, which BigQuery cannot partition by or compare directly.
+func (m *mergeStmtGenerator) transformedPkeyStrings(forPartition bool) []string {
+	pkeys := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns))
+	columnNameTypeMap := make(map[string]qvalue.QValueKind, len(m.normalizedTableSchema.Columns))
+	for _, col := range m.normalizedTableSchema.Columns {
+		columnNameTypeMap[col.Name] = qvalue.QValueKind(col.Type)
+	}
+
+	for _, pkeyCol := range m.normalizedTableSchema.PrimaryKeyColumns {
+		pkeyColType, ok := columnNameTypeMap[pkeyCol]
+		if !ok {
+			continue
+		}
+		switch pkeyColType {
+		case qvalue.QValueKindJSON:
+			if forPartition {
+				pkeys = append(pkeys, fmt.Sprintf("TO_JSON_STRING(%s)", m.shortColumn[pkeyCol]))
+			} else {
+				pkeys = append(pkeys, fmt.Sprintf("TO_JSON_STRING(_t.`%s`)=TO_JSON_STRING(_d.%s)",
+					pkeyCol, m.shortColumn[pkeyCol]))
+			}
+		case qvalue.QValueKindFloat32, qvalue.QValueKindFloat64:
+			if forPartition {
+				pkeys = append(pkeys, fmt.Sprintf("CAST(%s as STRING)", m.shortColumn[pkeyCol]))
+			} else {
+				pkeys = append(pkeys, fmt.Sprintf("_t.`%s`=_d.%s", pkeyCol, m.shortColumn[pkeyCol]))
+			}
+		default:
+			if forPartition {
+				pkeys = append(pkeys, m.shortColumn[pkeyCol])
+			} else {
+				pkeys = append(pkeys, fmt.Sprintf("_t.`%s`=_d.%s", pkeyCol, m.shortColumn[pkeyCol]))
+			}
+		}
+	}
+	return pkeys
+}
+
 // generateDeDupedCTE generates a de-duped CTE.
 func (m *mergeStmtGenerator) generateDeDupedCTE() string {
 	const cte = `_dd AS (
@@ -111,13 +149,8 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
 			WHERE _peerdb_rank=1
 	) SELECT * FROM _dd`
 
-	shortPkeys := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns))
-	for _, pkeyCol := range m.normalizedTableSchema.PrimaryKeyColumns {
-		shortPkeys = append(shortPkeys, m.shortColumn[pkeyCol])
-	}
-
-	pkeyColsStr := fmt.Sprintf("(CONCAT(%s))", strings.Join(shortPkeys,
-		", '_peerdb_concat_', "))
+	shortPkeys := m.transformedPkeyStrings(true)
+	pkeyColsStr := strings.Join(shortPkeys, ",")
 
 	return fmt.Sprintf(cte, pkeyColsStr)
 }
@@ -151,11 +184,7 @@ func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) s
 	}
 	updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")
 
-	pkeySelectSQLArray := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns))
-	for _, pkeyColName := range m.normalizedTableSchema.PrimaryKeyColumns {
-		pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("_t.%s=_d.%s",
-			pkeyColName, m.shortColumn[pkeyColName]))
-	}
+	pkeySelectSQLArray := m.transformedPkeyStrings(false)
 	// t.<pkey1> = d.<pkey1> AND t.<pkey2> = d.<pkey2> ...
 	pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ")
 
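For reference, a minimal standalone Go sketch of what the clause generation above emits for a hypothetical (JSON, TEXT) primary key. The qvalue kinds are reduced to plain strings and the shortColumn aliases (_c1, _c2) are invented for illustration; neither is the connector's real type.

// Standalone sketch; kinds and short aliases are stubbed, not the connector's types.
package main

import (
	"fmt"
	"strings"
)

// Stand-ins for qvalue.QValueKind values.
const (
	kindJSON   = "json"
	kindString = "string"
)

// transformPkey mirrors the switch in transformedPkeyStrings for one column.
func transformPkey(kind, col, short string, forPartition bool) string {
	switch kind {
	case kindJSON:
		if forPartition {
			return fmt.Sprintf("TO_JSON_STRING(%s)", short)
		}
		return fmt.Sprintf("TO_JSON_STRING(_t.`%s`)=TO_JSON_STRING(_d.%s)", col, short)
	default:
		if forPartition {
			return short
		}
		return fmt.Sprintf("_t.`%s`=_d.%s", col, short)
	}
}

func main() {
	// Hypothetical two-column primary key: a JSON column and a TEXT column.
	cols := []struct{ name, short, kind string }{
		{"j", "_c1", kindJSON},
		{"key", "_c2", kindString},
	}

	var partition, join []string
	for _, c := range cols {
		partition = append(partition, transformPkey(c.kind, c.name, c.short, true))
		join = append(join, transformPkey(c.kind, c.name, c.short, false))
	}

	// Prints: PARTITION BY TO_JSON_STRING(_c1),_c2
	fmt.Println("PARTITION BY " + strings.Join(partition, ","))
	// Prints: TO_JSON_STRING(_t.`j`)=TO_JSON_STRING(_d._c1) AND _t.`key`=_d._c2
	fmt.Println(strings.Join(join, " AND "))
}

Joining the partition expressions with plain commas, rather than the old CONCAT(...) with a '_peerdb_concat_' separator, keeps each key as its own PARTITION BY expression, which also avoids passing JSON values to CONCAT.
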
diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go
index 9d0b8f6a57..74a1949ff7 100644
--- a/flow/e2e/bigquery/peer_flow_bq_test.go
+++ b/flow/e2e/bigquery/peer_flow_bq_test.go
@@ -1514,3 +1514,56 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_Insert_After_Delete() {
 	require.NoError(s.t, err)
 	require.Equal(s.t, int64(0), numNewRows)
 }
+
+func (s PeerFlowE2ETestSuiteBQ) Test_JSON_PKey_BQ() {
+	env := e2e.NewTemporalTestWorkflowEnvironment(s.t)
+
+	srcTableName := s.attachSchemaSuffix("test_json_pkey_bq")
+	dstTableName := "test_json_pkey_bq"
+
+	_, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`
+		CREATE TABLE IF NOT EXISTS %s (
+			id SERIAL NOT NULL,
+			j JSON NOT NULL,
+			key TEXT NOT NULL,
+			value TEXT NOT NULL
+		);
+	`, srcTableName))
+	require.NoError(s.t, err)
+
+	_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`
+		ALTER TABLE %s REPLICA IDENTITY FULL
+	`, srcTableName))
+	require.NoError(s.t, err)
+
+	connectionGen := e2e.FlowConnectionGenerationConfig{
+		FlowJobName:      s.attachSuffix("test_json_pkey_flow"),
+		TableNameMapping: map[string]string{srcTableName: dstTableName},
+		Destination:      s.bqHelper.Peer,
+		CdcStagingPath:   "",
+	}
+
+	flowConnConfig := connectionGen.GenerateFlowConnectionConfigs()
+	flowConnConfig.MaxBatchSize = 100
+
+	go func() {
+		e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen)
+		// insert 10 rows into the source table
+		for i := range 10 {
+			testKey := fmt.Sprintf("test_key_%d", i)
+			testValue := fmt.Sprintf("test_value_%d", i)
+			testJson := `'{"name":"jack", "age":12, "spouse":null}'::json`
+			_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`
+				INSERT INTO %s(key, value, j) VALUES ($1, $2, %s)
+			`, srcTableName, testJson), testKey, testValue)
+			e2e.EnvNoError(s.t, env, err)
+		}
+		s.t.Log("Inserted 10 rows into the source table")
+
+		e2e.EnvWaitForEqualTables(env, s, "normalize inserts", dstTableName, "id,key,value,j")
+		env.CancelWorkflow()
+	}()
+
+	env.ExecuteWorkflow(peerflow.CDCFlowWorkflow, flowConnConfig, nil)
+	e2e.RequireEnvCanceled(s.t, env)
+}
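The suite itself verifies the replicated rows via e2e.EnvWaitForEqualTables. For a manual spot-check of the destination table, a query along these lines could be used; it compares the JSON pkey through TO_JSON_STRING, the same normalization the new merge join condition applies. This is a sketch under assumed project and dataset names ("my-project", "my_dataset"), not part of the test.

// Hedged spot-check sketch; project and dataset names are placeholders.
package main

import (
	"context"
	"fmt"
	"log"

	"cloud.google.com/go/bigquery"
)

func main() {
	ctx := context.Background()

	// "my-project" is a placeholder; the e2e suite derives the real project
	// and dataset from its own configuration.
	client, err := bigquery.NewClient(ctx, "my-project")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// Count rows whose JSON pkey matches the document the test inserts;
	// TO_JSON_STRING renders both sides as strings so they become comparable.
	q := client.Query(`
		SELECT COUNT(*) FROM my_dataset.test_json_pkey_bq
		WHERE TO_JSON_STRING(j) = TO_JSON_STRING(JSON '{"name":"jack","age":12,"spouse":null}')`)
	it, err := q.Read(ctx)
	if err != nil {
		log.Fatal(err)
	}

	var row []bigquery.Value
	if err := it.Next(&row); err != nil {
		log.Fatal(err)
	}
	fmt.Println("matching rows:", row[0]) // expect 10 once normalization completes
}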