mixed case table and column name support for CDC mirrors in Snowflake and Postgres
heavycrystal committed Oct 29, 2023
1 parent 38402a1 commit d2c930d
Showing 10 changed files with 377 additions and 456 deletions.
10 changes: 6 additions & 4 deletions flow/connectors/postgres/client.go
@@ -497,6 +497,7 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
}
}
flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
+ parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)

insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",")
updateColumnsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
@@ -513,11 +514,11 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie

fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), internalSchema,
- rawTableIdentifier, destinationTableIdentifier, insertColumnsSQL, flattenedCastsSQL,
+ rawTableIdentifier, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL,
strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL)
fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
strings.Join(maps.Values(primaryKeyColumnCasts), ","), internalSchema,
- rawTableIdentifier, destinationTableIdentifier, deleteWhereClauseSQL)
+ rawTableIdentifier, parsedDstTable.String(), deleteWhereClauseSQL)

return []string{fallbackUpsertStatement, fallbackDeleteStatement}
}
@@ -529,6 +530,7 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
for i, columnName := range columnNames {
columnNames[i] = fmt.Sprintf("\"%s\"", columnName)
}
+ parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier)

flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
primaryKeyColumnCasts := make(map[string]string)
@@ -560,7 +562,7 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
updateStatements := c.generateUpdateStatement(columnNames, unchangedToastColumns)

return fmt.Sprintf(mergeStatementSQL, strings.Join(maps.Values(primaryKeyColumnCasts), ","),
- internalSchema, rawTableIdentifier, destinationTableIdentifier, flattenedCastsSQL,
+ internalSchema, rawTableIdentifier, parsedDstTable.String(), flattenedCastsSQL,
strings.Join(primaryKeySelectSQLArray, " AND "), insertColumnsSQL, insertValuesSQL, updateStatements)
}

@@ -603,7 +605,7 @@ func (c *PostgresConnector) getApproxTableCounts(tables []string) (int64, error)
err := row.Scan(&count)
if err != nil {
log.WithFields(log.Fields{
"table": table,
"table": parsedTable.String(),
}).Errorf("error while scanning row: %v", err)
return fmt.Errorf("error while scanning row: %w", err)
}
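The common thread in these client.go changes: every place the destination table identifier is interpolated into SQL now goes through utils.ParseSchemaTable and SchemaTable.String(), so mixed-case destination names stay quoted instead of being folded to lowercase by Postgres. A minimal sketch of the helper this relies on, assuming String() double-quotes both parts (the real implementation lives under flow/connectors/utils and may differ in details):

// Hypothetical sketch of utils.SchemaTable; the real helper may differ.
package utils

import (
	"fmt"
	"strings"
)

// SchemaTable holds a parsed <schema>.<table> destination identifier.
type SchemaTable struct {
	Schema string
	Table  string
}

// String renders the identifier with both parts double-quoted, so Postgres
// preserves mixed-case names instead of folding them to lowercase.
func (s *SchemaTable) String() string {
	return fmt.Sprintf(`"%s"."%s"`, s.Schema, s.Table)
}

// ParseSchemaTable splits a plain <schema>.<table> string into components.
func ParseSchemaTable(table string) (*SchemaTable, error) {
	parts := strings.Split(table, ".")
	if len(parts) != 2 {
		return nil, fmt.Errorf("invalid schema table %s, expected <schema>.<table>", table)
	}
	return &SchemaTable{Schema: parts[0], Table: parts[1]}, nil
}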
7 changes: 4 additions & 3 deletions flow/connectors/postgres/postgres.go
@@ -643,11 +643,11 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab
}()

for tableIdentifier, tableSchema := range req.TableNameSchemaMapping {
- normalizedTableNameComponents, err := utils.ParseSchemaTable(tableIdentifier)
+ parsedNormalizedTable, err := utils.ParseSchemaTable(tableIdentifier)
if err != nil {
return nil, fmt.Errorf("error while parsing table schema and name: %w", err)
}
- tableAlreadyExists, err := c.tableExists(normalizedTableNameComponents)
+ tableAlreadyExists, err := c.tableExists(parsedNormalizedTable)
if err != nil {
return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err)
}
@@ -657,7 +657,8 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab
}

// convert the column names and types to Postgres types
- normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(tableIdentifier, tableSchema)
+ normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(parsedNormalizedTable.String(),
+ tableSchema)
_, err = createNormalizedTablesTx.Exec(c.ctx, normalizedTableCreateSQL)
if err != nil {
return nil, fmt.Errorf("error while creating normalized table: %w", err)
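Why routing the DDL through parsedNormalizedTable.String() matters, sketched under the same quoting assumption as above (the connector's actual CREATE TABLE template is not shown in this diff):

// Hedged illustration: Postgres folds unquoted identifiers to lowercase,
// so only the quoted form preserves a mixed-case destination table.
parsed, _ := utils.ParseSchemaTable("public.EventLog")
fmt.Printf("CREATE TABLE IF NOT EXISTS %s(id INT)\n", "public.EventLog")
// -> creates public.eventlog: the unquoted name is case-folded
fmt.Printf("CREATE TABLE IF NOT EXISTS %s(id INT)\n", parsed.String())
// -> CREATE TABLE IF NOT EXISTS "public"."EventLog"(id INT): case preserved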
4 changes: 3 additions & 1 deletion flow/connectors/postgres/qrep_sync_method.go
@@ -199,7 +199,9 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
return -1, fmt.Errorf("failed to commit transaction: %v", err)
}

- totalRecordsAtTarget, err := s.connector.getApproxTableCounts([]string{dstTableName.String()})
+ // a conversion to SchemaTable happens in this function, so we cannot do the conversion here
+ totalRecordsAtTarget, err := s.connector.getApproxTableCounts([]string{
+ fmt.Sprintf("%s.%s", dstTableName.Schema, dstTableName.Table)})
if err != nil {
return -1, fmt.Errorf("failed to get total records at target: %v", err)
}
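The new comment is the crux of this hunk: getApproxTableCounts runs utils.ParseSchemaTable over each argument itself, so the caller must pass the plain <schema>.<table> form rather than dstTableName.String(). Under the SchemaTable sketch above, the quoted form would not round-trip:

// Hypothetical round-trip, building on the sketch above.
dst := &utils.SchemaTable{Schema: "MySchema", Table: "MyTable"}
bad, _ := utils.ParseSchemaTable(dst.String())
fmt.Println(bad.Schema) // `"MySchema"` -- the quotes become part of the name
good, _ := utils.ParseSchemaTable(fmt.Sprintf("%s.%s", dst.Schema, dst.Table))
fmt.Println(good.Schema, good.Table) // MySchema MyTable -- clean round-trip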
22 changes: 20 additions & 2 deletions flow/connectors/snowflake/client.go
@@ -3,12 +3,14 @@ package connsnowflake
import (
"context"
"fmt"
"strings"
"time"

"github.com/jmoiron/sqlx"
"github.com/snowflakedb/gosnowflake"

peersql "github.com/PeerDB-io/peer-flow/connectors/sql"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
util "github.com/PeerDB-io/peer-flow/utils"
)
@@ -67,12 +69,13 @@ func NewSnowflakeClient(ctx context.Context, config *protos.SnowflakeConfig) (*S
func (c *SnowflakeConnector) getTableCounts(tables []string) (int64, error) {
var totalRecords int64
for _, table := range tables {
- _, err := parseTableName(table)
+ parsedSchemaTable, err := utils.ParseSchemaTable(table)
if err != nil {
return 0, fmt.Errorf("failed to parse table name %s: %w", table, err)
}
//nolint:gosec
row := c.database.QueryRowContext(c.ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table))
row := c.database.QueryRowContext(c.ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s",
snowflakeSchemaTableNormalize(parsedSchemaTable)))
var count int64
err = row.Scan(&count)
if err != nil {
@@ -82,3 +85,18 @@ func (c *SnowflakeConnector) getTableCounts(tables []string) (int64, error) {
}
return totalRecords, nil
}

+ func snowflakeIdentifierNormalize(identifier string) string {
+ // https://www.alberton.info/dbms_identifiers_and_case_sensitivity.html
+ // Snowflake follows the SQL standard, but Postgres does the opposite.
+ // Ergo, we suffer.
+ if strings.ToLower(identifier) == identifier {
+ return fmt.Sprintf(`"%s"`, strings.ToUpper(identifier))
+ }
+ return fmt.Sprintf(`"%s"`, identifier)
+ }
+
+ func snowflakeSchemaTableNormalize(schemaTable *utils.SchemaTable) string {
+ return fmt.Sprintf(`%s.%s`, snowflakeIdentifierNormalize(schemaTable.Schema),
+ snowflakeIdentifierNormalize(schemaTable.Table))
+ }
34 changes: 20 additions & 14 deletions flow/connectors/snowflake/qrep.go
@@ -54,7 +54,7 @@ func (c *SnowflakeConnector) SyncQRepRecords(
case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT:
return 0, fmt.Errorf("multi-insert sync mode not supported for snowflake")
case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO:
- avroSync := NewSnowflakeAvroSyncMethod(config, c)
+ avroSync := NewSnowflakeAvroSyncHandler(config, c)
return avroSync.SyncQRepRecords(config, partition, tblSchema, stream)
default:
return 0, fmt.Errorf("unsupported sync mode: %s", syncMode)
@@ -86,12 +86,17 @@ func (c *SnowflakeConnector) createMetadataInsertStatement(
}

func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType, error) {
+ parsedTableName, err := utils.ParseSchemaTable(tableName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse table '%s'", tableName)
+ }
+
//nolint:gosec
queryString := fmt.Sprintf(`
SELECT *
FROM %s
LIMIT 0
- `, tableName)
+ `, snowflakeSchemaTableNormalize(parsedTableName))

rows, err := c.database.Query(queryString)
if err != nil {
@@ -253,16 +258,15 @@ func (c *SnowflakeConnector) ConsolidateQRepPartitions(config *protos.QRepConfig
case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT:
return fmt.Errorf("multi-insert sync mode not supported for snowflake")
case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO:
- colInfo, err := c.getColsFromTable(destTable)
+ colInfo, err := c.getColsFromTable(destTable, true)
if err != nil {
log.WithFields(log.Fields{
"flowName": config.FlowJobName,
}).Errorf("failed to get columns from table %s: %v", destTable, err)
return fmt.Errorf("failed to get columns from table %s: %w", destTable, err)
}

- allCols := colInfo.Columns
- err = CopyStageToDestination(c, config, destTable, stageName, allCols)
+ err = copyStageToDestination(c, config, destTable, stageName, colInfo)
if err != nil {
log.WithFields(log.Fields{
"flowName": config.FlowJobName,
@@ -284,23 +288,20 @@ func (c *SnowflakeConnector) CleanupQRepFlow(config *protos.QRepConfig) error {
return c.dropStage(config.StagingPath, config.FlowJobName)
}

- func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnInformation, error) {
+ func (c *SnowflakeConnector) getColsFromTable(tableName string,
+ correctForAvro bool) (*model.ColumnInformation, error) {
// parse the table name to get the schema and table name
- components, err := parseTableName(tableName)
+ components, err := utils.ParseSchemaTable(tableName)
if err != nil {
return nil, fmt.Errorf("failed to parse table name: %w", err)
}

- // convert tableIdentifier and schemaIdentifier to upper case
- components.tableIdentifier = strings.ToUpper(components.tableIdentifier)
- components.schemaIdentifier = strings.ToUpper(components.schemaIdentifier)
-
//nolint:gosec
queryString := fmt.Sprintf(`
SELECT column_name, data_type
FROM information_schema.columns
WHERE UPPER(table_name) = '%s' AND UPPER(table_schema) = '%s'
- `, components.tableIdentifier, components.schemaIdentifier)
+ `, strings.ToUpper(components.Table), strings.ToUpper(components.Schema))

rows, err := c.database.Query(queryString)
if err != nil {
@@ -309,12 +310,17 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnIn
defer rows.Close()

columnMap := map[string]string{}
- var colName string
- var colType string
for rows.Next() {
+ var colName string
+ var colType string
if err := rows.Scan(&colName, &colType); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
+ // Avro file was written with caseless identifiers being lowercase, as the information is fetched from Postgres
+ // Snowflake retrieves the column information with caseless identifiers being UPPERCASE
+ if correctForAvro && strings.ToUpper(colName) == colName {
+ colName = strings.ToLower(colName)
+ }
columnMap[colName] = colType
}
var cols []string
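The correctForAvro branch reconciles two case-folding regimes: the Avro files were written with Postgres-folded (lowercase) column names, while Snowflake's information_schema reports caseless identifiers in uppercase. Folding all-uppercase names back to lowercase lets the column map line up with the Avro fields; deliberately mixed-case names pass through untouched. A small illustration:

// Illustrative: mapping information_schema output back to Avro field names.
cols := []string{"ID", "UPDATED_AT", "UpdatedAt"} // as reported by Snowflake
for i, colName := range cols {
	if strings.ToUpper(colName) == colName {
		cols[i] = strings.ToLower(colName) // "ID" -> "id", "UPDATED_AT" -> "updated_at"
	}
	// "UpdatedAt" keeps its deliberate mixed case
}
fmt.Println(cols) // [id updated_at UpdatedAt]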