partial mixed case support for SF
heavycrystal committed Dec 28, 2023
1 parent c9e5e90 commit 29d4770
Showing 8 changed files with 144 additions and 103 deletions.
16 changes: 16 additions & 0 deletions flow/connectors/snowflake/client.go
@@ -3,6 +3,7 @@ package connsnowflake
import (
"context"
"fmt"
"strings"
"time"

"github.com/jmoiron/sqlx"
@@ -84,3 +85,18 @@ func (c *SnowflakeConnector) getTableCounts(tables []string) (int64, error) {
}
return totalRecords, nil
}

func SnowflakeIdentifierNormalize(identifier string) string {
// https://www.alberton.info/dbms_identifiers_and_case_sensitivity.html
// Snowflake follows the SQL standard, but Postgres does the opposite.
// Ergo, we suffer.
if utils.IsLower(identifier) {
return fmt.Sprintf(`"%s"`, strings.ToUpper(identifier))
}
return fmt.Sprintf(`"%s"`, identifier)
}

func snowflakeSchemaTableNormalize(schemaTable *utils.SchemaTable) string {
return fmt.Sprintf(`%s.%s`, SnowflakeIdentifierNormalize(schemaTable.Schema),
SnowflakeIdentifierNormalize(schemaTable.Table))
}
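For orientation, a rough sketch of what the two helpers above produce (hypothetical example, written as if it lived inside package connsnowflake, not part of this commit): all-lowercase identifiers are uppercased before quoting, while anything containing an upper-case letter is quoted verbatim.

// Hypothetical illustration only; assumes access to the unexported helpers above.
package connsnowflake

import "fmt"

func exampleIdentifierNormalize() {
	fmt.Println(SnowflakeIdentifierNormalize("event_id")) // "EVENT_ID"  (all lowercase → uppercase, then quote)
	fmt.Println(SnowflakeIdentifierNormalize("eventId"))  // "eventId"   (mixed case → quote verbatim)
	fmt.Println(SnowflakeIdentifierNormalize("EVENT_ID")) // "EVENT_ID"  (already uppercase → quote verbatim)
	// snowflakeSchemaTableNormalize applies the same rule to each part:
	// a schema/table pair (public, Users) becomes "PUBLIC"."Users"
}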
11 changes: 8 additions & 3 deletions flow/connectors/snowflake/qrep.go
@@ -77,12 +77,17 @@ func (c *SnowflakeConnector) createMetadataInsertStatement(
}

func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType, error) {
schematable, err := utils.ParseSchemaTable(tableName)
if err != nil {
return nil, fmt.Errorf("failed to parse table '%s'", tableName)
}

//nolint:gosec
queryString := fmt.Sprintf(`
SELECT *
FROM %s
LIMIT 0
`, tableName)
`, snowflakeSchemaTableNormalize(schematable))

rows, err := c.database.QueryContext(c.ctx, queryString)
if err != nil {
@@ -296,10 +301,10 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnIn
}
defer rows.Close()

var colName pgtype.Text
var colType pgtype.Text
columnMap := map[string]string{}
for rows.Next() {
var colName pgtype.Text
var colType pgtype.Text
if err := rows.Scan(&colName, &colType); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
46 changes: 24 additions & 22 deletions flow/connectors/snowflake/qrep_avro_sync.go
@@ -302,30 +302,36 @@ func (c *SnowflakeConnector) GetCopyTransformation(
) (*CopyInfo, error) {
colInfo, colsErr := c.getColsFromTable(dstTableName)
if colsErr != nil {
return nil, fmt.Errorf("failed to get columns from destination table: %w", colsErr)
return nil, fmt.Errorf("failed to get columns from destination table: %w", colsErr)
}

transformations := make([]string, 0, len(colInfo.ColumnMap))
columnOrder := make([]string, 0, len(colInfo.ColumnMap))
for colName, colType := range colInfo.ColumnMap {
columnOrder = append(columnOrder, fmt.Sprintf("\"%s\"", colName))
if colName == syncedAtCol {
transformations = append(transformations, fmt.Sprintf("CURRENT_TIMESTAMP AS \"%s\"", colName))
for avroColName, colType := range colInfo.ColumnMap {
normalizedColName := SnowflakeIdentifierNormalize(avroColName)
columnOrder = append(columnOrder, normalizedColName)
if avroColName == syncedAtCol {
transformations = append(transformations, fmt.Sprintf("CURRENT_TIMESTAMP AS %s", normalizedColName))
continue
}

if utils.IsUpper(avroColName) {
avroColName = strings.ToLower(avroColName)
}
// Avro files are written with lowercase in mind, so don't normalize it like everything else
switch colType {
case "GEOGRAPHY":
transformations = append(transformations,
fmt.Sprintf("TO_GEOGRAPHY($1:\"%s\"::string, true) AS \"%s\"", strings.ToLower(colName), colName))
fmt.Sprintf("TO_GEOGRAPHY($1:\"%s\"::string, true) AS %s", avroColName, normalizedColName))
case "GEOMETRY":
transformations = append(transformations,
fmt.Sprintf("TO_GEOMETRY($1:\"%s\"::string, true) AS \"%s\"", strings.ToLower(colName), colName))
fmt.Sprintf("TO_GEOMETRY($1:\"%s\"::string, true) AS %s", avroColName, normalizedColName))
case "NUMBER":
transformations = append(transformations,
fmt.Sprintf("$1:\"%s\" AS \"%s\"", strings.ToLower(colName), colName))
fmt.Sprintf("$1:\"%s\" AS %s", avroColName, normalizedColName))
default:
transformations = append(transformations,
fmt.Sprintf("($1:\"%s\")::%s AS \"%s\"", strings.ToLower(colName), colType, colName))
fmt.Sprintf("($1:\"%s\")::%s AS %s", avroColName, colType, normalizedColName))
}
}
transformationSQL := strings.Join(transformations, ",")
@@ -361,14 +367,12 @@ func CopyStageToDestination(
if err != nil {
return fmt.Errorf("failed to get copy transformation: %w", err)
}
switch appendMode {
case true:
if appendMode {
err := writeHandler.HandleAppendMode(copyTransformation)
if err != nil {
return fmt.Errorf("failed to handle append mode: %w", err)
}

case false:
} else {
upsertKeyCols := config.WriteMode.UpsertKeyColumns
err := writeHandler.HandleUpsertMode(allCols, upsertKeyCols, config.WatermarkColumn,
config.FlowJobName, copyTransformation)
@@ -428,9 +432,11 @@ func NewSnowflakeAvroWriteHandler(
func (s *SnowflakeAvroWriteHandler) HandleAppendMode(
copyInfo *CopyInfo,
) error {
parsedDstTable, _ := utils.ParseSchemaTable(s.dstTableName)
//nolint:gosec
copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s",
s.dstTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ","))
snowflakeSchemaTableNormalize(parsedDstTable), copyInfo.columnsSQL,
copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ","))
s.connector.logger.Info("running copy command: " + copyCmd)
_, err := s.connector.database.ExecContext(s.connector.ctx, copyCmd)
if err != nil {
@@ -441,13 +447,12 @@ func (s *SnowflakeAvroWriteHandler) HandleAppendMode(
return nil
}

func GenerateMergeCommand(
func generateUpsertMergeCommand(
allCols []string,
upsertKeyCols []string,
watermarkCol string,
tempTableName string,
dstTable string,
) (string, error) {
) string {
// all cols are acquired from snowflake schema, so let us try to make upsert key cols match the case
// and also the watermark col, then the quoting should be fine
caseMatchedCols := map[string]string{}
@@ -495,7 +500,7 @@ func GenerateMergeCommand(
`, dstTable, selectCmd, upsertKeyClause,
updateSetClause, insertColumnsClause, insertValuesClause)

return mergeCmd, nil
return mergeCmd
}

// HandleUpsertMode handles the upsert mode
Expand Down Expand Up @@ -530,10 +535,7 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode(
}
s.connector.logger.Info("copied file from stage " + s.stage + " to temp table " + tempTableName)

mergeCmd, err := GenerateMergeCommand(allCols, upsertKeyCols, watermarkCol, tempTableName, s.dstTableName)
if err != nil {
return fmt.Errorf("failed to generate merge command: %w", err)
}
mergeCmd := generateUpsertMergeCommand(allCols, upsertKeyCols, tempTableName, s.dstTableName)

startTime := time.Now()
rows, err := s.connector.database.ExecContext(s.connector.ctx, mergeCmd)
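As a rough illustration of the GetCopyTransformation change above (hypothetical column names, not part of this commit): the Avro field reference stays or becomes lowercase, while the AS alias uses the normalized destination identifier.

// Hypothetical illustration only; the strings below mirror the format strings in GetCopyTransformation.
package connsnowflake

import "fmt"

func exampleCopyTransformations() {
	// GEOGRAPHY column stored as all-uppercase "LOCATION": the Avro field is lowercased.
	fmt.Println(`TO_GEOGRAPHY($1:"location"::string, true) AS "LOCATION"`)
	// GEOGRAPHY column stored with mixed case "pickupPoint": the Avro field keeps its case.
	fmt.Println(`TO_GEOGRAPHY($1:"pickupPoint"::string, true) AS "pickupPoint"`)
}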
53 changes: 30 additions & 23 deletions flow/connectors/snowflake/snowflake.go
@@ -435,7 +435,7 @@ func (c *SnowflakeConnector) SetupNormalizedTables(
}

normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(
tableIdentifier, tableSchema, req.SoftDeleteColName, req.SyncedAtColName)
normalizedSchemaTable, tableSchema, req.SoftDeleteColName, req.SyncedAtColName)
_, err = c.database.ExecContext(c.ctx, normalizedTableCreateSQL)
if err != nil {
return nil, fmt.Errorf("[sf] error while creating normalized table: %w", err)
@@ -562,8 +562,8 @@ func (c *SnowflakeConnector) syncRecordsViaAvro(
qrepConfig := &protos.QRepConfig{
StagingPath: "",
FlowJobName: req.FlowJobName,
DestinationTableIdentifier: fmt.Sprintf("%s.%s", c.metadataSchema,
rawTableIdentifier),
DestinationTableIdentifier: strings.ToLower(fmt.Sprintf("%s.%s", c.metadataSchema,
rawTableIdentifier)),
}
avroSyncer := NewSnowflakeAvroSyncMethod(qrepConfig, c)
destinationTableSchema, err := c.getTableSchema(qrepConfig.DestinationTableIdentifier)
@@ -759,50 +759,50 @@ func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableId
}

func generateCreateTableSQLForNormalizedTable(
sourceTableIdentifier string,
dstSchemaTable *utils.SchemaTable,
sourceTableSchema *protos.TableSchema,
softDeleteColName string,
syncedAtColName string,
) string {
createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns)+2)
for columnName, genericColumnType := range sourceTableSchema.Columns {
columnNameUpper := strings.ToUpper(columnName)
normalizedColName := SnowflakeIdentifierNormalize(columnName)
sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))
if err != nil {
slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType),
slog.Any("error", err))
continue
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper, sfColType))
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s,`, normalizedColName, sfColType))
}

// add a _peerdb_is_deleted column to the normalized table
// this is boolean default false, and is used to mark records as deleted
if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`"%s" BOOLEAN DEFAULT FALSE,`, softDeleteColName))
fmt.Sprintf(`%s BOOLEAN DEFAULT FALSE,`, softDeleteColName))
}

// add a _peerdb_synced column to the normalized table
// this is a timestamp column that is used to mark records as synced
// default value is the current timestamp (snowflake)
if syncedAtColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`"%s" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName))
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName))
}

// add composite primary key to the table
if len(sourceTableSchema.PrimaryKeyColumns) > 0 {
primaryKeyColsUpperQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns))
normalizedPrimaryKeyCols := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns))
for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
primaryKeyColsUpperQuoted = append(primaryKeyColsUpperQuoted,
fmt.Sprintf(`"%s"`, strings.ToUpper(primaryKeyCol)))
normalizedPrimaryKeyCols = append(normalizedPrimaryKeyCols,
SnowflakeIdentifierNormalize(primaryKeyCol))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(primaryKeyColsUpperQuoted, ","), ",")))
strings.TrimSuffix(strings.Join(normalizedPrimaryKeyCols, ","), ",")))
}

return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
return fmt.Sprintf(createNormalizedTableSQL, snowflakeSchemaTableNormalize(dstSchemaTable),
strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
}

@@ -821,6 +821,10 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
normalizeReq *model.NormalizeRecordsRequest,
) (int64, error) {
normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
parsedDstTable, err := utils.ParseSchemaTable(destinationTableIdentifier)
if err != nil {
return 0, fmt.Errorf("unable to parse destination table '%s'", parsedDstTable)
}
columnNames := maps.Keys(normalizedTableSchema.Columns)

flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
@@ -832,7 +836,7 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
genericColumnType, err)
}

targetColumnName := fmt.Sprintf(`"%s"`, strings.ToUpper(columnName))
targetColumnName := SnowflakeIdentifierNormalize(columnName)
switch qvalue.QValueKind(genericColumnType) {
case qvalue.QValueKindBytes, qvalue.QValueKindBit:
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+
@@ -865,7 +869,7 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(

quotedUpperColNames := make([]string, 0, len(columnNames))
for _, columnName := range columnNames {
quotedUpperColNames = append(quotedUpperColNames, fmt.Sprintf(`"%s"`, strings.ToUpper(columnName)))
quotedUpperColNames = append(quotedUpperColNames, SnowflakeIdentifierNormalize(columnName))
}
// append synced_at column
quotedUpperColNames = append(quotedUpperColNames,
@@ -876,8 +880,8 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(

insertValuesSQLArray := make([]string, 0, len(columnNames))
for _, columnName := range columnNames {
quotedUpperColumnName := fmt.Sprintf(`"%s"`, strings.ToUpper(columnName))
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("SOURCE.%s", quotedUpperColumnName))
normalizedColName := SnowflakeIdentifierNormalize(columnName)
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("SOURCE.%s", normalizedColName))
}
// fill in synced_at column
insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP")
@@ -899,10 +903,13 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
}
updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")

normalizedpkeyColsArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
pkeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
for _, pkeyColName := range normalizedTableSchema.PrimaryKeyColumns {
normalizedPkeyColName := SnowflakeIdentifierNormalize(pkeyColName)
normalizedpkeyColsArray = append(normalizedpkeyColsArray, normalizedPkeyColName)
pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s",
pkeyColName, pkeyColName))
normalizedPkeyColName, normalizedPkeyColName))
}
// TARGET.<pkey1> = SOURCE.<pkey1> AND TARGET.<pkey2> = SOURCE.<pkey2> ...
pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ")
@@ -916,9 +923,9 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
}
}

mergeStatement := fmt.Sprintf(mergeStatementSQL, destinationTableIdentifier, toVariantColumnName,
rawTableIdentifier, normalizeBatchID, syncBatchID, flattenedCastsSQL,
fmt.Sprintf("(%s)", strings.Join(normalizedTableSchema.PrimaryKeyColumns, ",")),
mergeStatement := fmt.Sprintf(mergeStatementSQL, snowflakeSchemaTableNormalize(parsedDstTable),
toVariantColumnName, rawTableIdentifier, normalizeBatchID, syncBatchID, flattenedCastsSQL,
fmt.Sprintf("(%s)", strings.Join(normalizedpkeyColsArray, ",")),
pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)

startTime := time.Now()
@@ -1045,8 +1052,8 @@ func (c *SnowflakeConnector) generateUpdateStatements(
otherCols := utils.ArrayMinus(allCols, unchangedColsArray)
tmpArray := make([]string, 0, len(otherCols)+2)
for _, colName := range otherCols {
quotedUpperColName := fmt.Sprintf(`"%s"`, strings.ToUpper(colName))
tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", quotedUpperColName, quotedUpperColName))
normalizedColName := SnowflakeIdentifierNormalize(colName)
tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", normalizedColName, normalizedColName))
}

// set the synced at column to the current timestamp
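A rough sketch of the identifier handling above for a destination table with columns id (lowercase) and pickupPoint (mixed case) and primary key id. The column names are hypothetical, and STRING is assumed as the mapped column type purely for illustration.

// Hypothetical illustration only; the fragments mirror the format strings used above.
package connsnowflake

import "fmt"

func exampleNormalizedIdentifiers() {
	// generateCreateTableSQLForNormalizedTable emits one fragment per column:
	fmt.Println(`"ID" STRING,`)          // lowercase source column: uppercased and quoted
	fmt.Println(`"pickupPoint" STRING,`) // mixed-case source column: quoted verbatim
	fmt.Println(`PRIMARY KEY("ID"),`)    // primary key columns go through the same normalization
	// generateAndExecuteMergeStatement joins TARGET and SOURCE on the normalized keys:
	fmt.Println(`TARGET."ID" = SOURCE."ID"`)
}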
21 changes: 21 additions & 0 deletions flow/connectors/utils/identifiers.go
@@ -3,6 +3,7 @@ package utils
import (
"fmt"
"strings"
"unicode"
)

func QuoteIdentifier(identifier string) string {
@@ -28,3 +29,23 @@ func ParseSchemaTable(tableName string) (*SchemaTable, error) {

return &SchemaTable{schema, table}, nil
}

// IsUpper reports whether every letter in s is uppercase; non-letter runes are ignored.
func IsUpper(s string) bool {
for _, r := range s {
if !unicode.IsUpper(r) && unicode.IsLetter(r) {
return false
}
}
return true
}

// IsLower reports whether every letter in s is lowercase; non-letter runes are ignored.
func IsLower(s string) bool {
for _, r := range s {
if !unicode.IsLower(r) && unicode.IsLetter(r) {
return false
}
}
return true
}
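Both predicates inspect only letter runes (via the unicode package, so they are not strictly ASCII-only); digits, underscores, and empty strings satisfy both. A quick hypothetical illustration:

// Hypothetical illustration only.
package utils

import "fmt"

func exampleCasePredicates() {
	fmt.Println(IsUpper("EVENT_ID"), IsLower("EVENT_ID")) // true false
	fmt.Println(IsUpper("event_id"), IsLower("event_id")) // false true
	fmt.Println(IsUpper("eventId"), IsLower("eventId"))   // false false (mixed case)
	fmt.Println(IsUpper("_123"), IsLower("_123"))         // true true  (no letters to inspect)
}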