partial mixed case support for SF #919

Merged · Dec 28, 2023 · 6 commits
Changes from 4 commits
16 changes: 16 additions & 0 deletions flow/connectors/snowflake/client.go
@@ -3,6 +3,7 @@ package connsnowflake
import (
"context"
"fmt"
"strings"
"time"

"github.com/jmoiron/sqlx"
@@ -84,3 +85,18 @@ func (c *SnowflakeConnector) getTableCounts(tables []string) (int64, error) {
}
return totalRecords, nil
}

func SnowflakeIdentifierNormalize(identifier string) string {
// https://www.alberton.info/dbms_identifiers_and_case_sensitivity.html
// Snowflake follows the SQL standard, but Postgres does the opposite.
// Ergo, we suffer.
if utils.IsLower(identifier) {
return fmt.Sprintf(`"%s"`, strings.ToUpper(identifier))
}
return fmt.Sprintf(`"%s"`, identifier)
}

func snowflakeSchemaTableNormalize(schemaTable *utils.SchemaTable) string {
return fmt.Sprintf(`%s.%s`, SnowflakeIdentifierNormalize(schemaTable.Schema),
SnowflakeIdentifierNormalize(schemaTable.Table))
}
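For context, a minimal, self-contained sketch of how the new helpers behave (the example identifiers are hypothetical, not from this PR; `normalize` and `isLower` mirror the logic above):

package main

import (
	"fmt"
	"strings"
	"unicode"
)

// isLower mirrors utils.IsLower: reports whether every letter in s is lower case.
func isLower(s string) bool {
	for _, r := range s {
		if unicode.IsLetter(r) && !unicode.IsLower(r) {
			return false
		}
	}
	return true
}

// normalize mirrors SnowflakeIdentifierNormalize: all-lowercase names (the
// common unquoted-Postgres case) are upper-cased to match Snowflake's default
// case folding; anything else is quoted verbatim.
func normalize(identifier string) string {
	if isLower(identifier) {
		return fmt.Sprintf(`"%s"`, strings.ToUpper(identifier))
	}
	return fmt.Sprintf(`"%s"`, identifier)
}

func main() {
	fmt.Println(normalize("users"))      // "USERS"      – lowercase is folded up
	fmt.Println(normalize("UserEvents")) // "UserEvents" – mixed case preserved
	fmt.Println(normalize("ID"))         // "ID"         – already upper, quoted as-is
}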
11 changes: 8 additions & 3 deletions flow/connectors/snowflake/qrep.go
@@ -77,12 +77,17 @@ func (c *SnowflakeConnector) createMetadataInsertStatement(
}

func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType, error) {
schematable, err := utils.ParseSchemaTable(tableName)
if err != nil {
return nil, fmt.Errorf("failed to parse table '%s'", tableName)
}

//nolint:gosec
queryString := fmt.Sprintf(`
SELECT *
FROM %s
LIMIT 0
`, tableName)
`, snowflakeSchemaTableNormalize(schematable))

rows, err := c.database.QueryContext(c.ctx, queryString)
if err != nil {
@@ -296,10 +301,10 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnIn
}
defer rows.Close()

var colName pgtype.Text
Contributor: why were these moved outside the loop?

@serprex (Contributor, Dec 28, 2023): I'd guess the goal is to avoid clearing an empty stack slot each loop, which may also be boxed depending on how Go analyzes Scan. This refactor has appeared in a few previous PRs.
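A generic sketch of the pattern under discussion (not code from this PR): the Scan targets are declared once before the loop and reused on every iteration, rather than re-declared each pass.

package example

import (
	"database/sql"
	"fmt"
)

// scanColumns hoists the Scan targets out of the loop, so each rows.Scan
// reuses the same two variables instead of zeroing fresh ones per iteration.
func scanColumns(rows *sql.Rows) (map[string]string, error) {
	columnMap := make(map[string]string)
	var name, typ string // declared once, outside the loop
	for rows.Next() {
		if err := rows.Scan(&name, &typ); err != nil {
			return nil, fmt.Errorf("failed to scan row: %w", err)
		}
		columnMap[name] = typ
	}
	return columnMap, rows.Err()
}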

var colType pgtype.Text
columnMap := map[string]string{}
for rows.Next() {
var colName pgtype.Text
var colType pgtype.Text
if err := rows.Scan(&colName, &colType); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
46 changes: 24 additions & 22 deletions flow/connectors/snowflake/qrep_avro_sync.go
@@ -302,30 +302,36 @@ func (c *SnowflakeConnector) GetCopyTransformation(
) (*CopyInfo, error) {
colInfo, colsErr := c.getColsFromTable(dstTableName)
if colsErr != nil {
return nil, fmt.Errorf("failed to get columns from destination table: %w", colsErr)
return nil, fmt.Errorf("failed to get columns from destination table: %w", colsErr)
}

transformations := make([]string, 0, len(colInfo.ColumnMap))
columnOrder := make([]string, 0, len(colInfo.ColumnMap))
for colName, colType := range colInfo.ColumnMap {
columnOrder = append(columnOrder, fmt.Sprintf("\"%s\"", colName))
if colName == syncedAtCol {
transformations = append(transformations, fmt.Sprintf("CURRENT_TIMESTAMP AS \"%s\"", colName))
for avroColName, colType := range colInfo.ColumnMap {
normalizedColName := SnowflakeIdentifierNormalize(avroColName)
columnOrder = append(columnOrder, normalizedColName)
if avroColName == syncedAtCol {
transformations = append(transformations, fmt.Sprintf("CURRENT_TIMESTAMP AS %s", normalizedColName))
continue
}

if utils.IsUpper(avroColName) {
avroColName = strings.ToLower(avroColName)
}
// Avro files are written with lowercase in mind, so don't normalize it like everything else
switch colType {
case "GEOGRAPHY":
transformations = append(transformations,
fmt.Sprintf("TO_GEOGRAPHY($1:\"%s\"::string, true) AS \"%s\"", strings.ToLower(colName), colName))
fmt.Sprintf("TO_GEOGRAPHY($1:\"%s\"::string, true) AS %s", avroColName, normalizedColName))
case "GEOMETRY":
transformations = append(transformations,
fmt.Sprintf("TO_GEOMETRY($1:\"%s\"::string, true) AS \"%s\"", strings.ToLower(colName), colName))
fmt.Sprintf("TO_GEOMETRY($1:\"%s\"::string, true) AS %s", avroColName, normalizedColName))
case "NUMBER":
transformations = append(transformations,
fmt.Sprintf("$1:\"%s\" AS \"%s\"", strings.ToLower(colName), colName))
fmt.Sprintf("$1:\"%s\" AS %s", avroColName, normalizedColName))
default:
transformations = append(transformations,
fmt.Sprintf("($1:\"%s\")::%s AS \"%s\"", strings.ToLower(colName), colType, colName))
fmt.Sprintf("($1:\"%s\")::%s AS %s", avroColName, colType, normalizedColName))
}
}
transformationSQL := strings.Join(transformations, ",")
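To make the transformation concrete, here is roughly what the loop above emits for a few hypothetical destination columns (illustrative output, assuming a lowercase GEOGRAPHY column shape, an all-upper NUMBER column ID, and a mixed-case TIMESTAMP column CreatedAt):

TO_GEOGRAPHY($1:"shape"::string, true) AS "SHAPE"
$1:"id" AS "ID"
($1:"CreatedAt")::TIMESTAMP AS "CreatedAt"

The source side uses the Avro field name (lower-cased when the column was all upper case), while the target side goes through SnowflakeIdentifierNormalize.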
@@ -361,14 +367,12 @@ func CopyStageToDestination(
if err != nil {
return fmt.Errorf("failed to get copy transformation: %w", err)
}
switch appendMode {
case true:
if appendMode {
err := writeHandler.HandleAppendMode(copyTransformation)
if err != nil {
return fmt.Errorf("failed to handle append mode: %w", err)
}

case false:
} else {
upsertKeyCols := config.WriteMode.UpsertKeyColumns
err := writeHandler.HandleUpsertMode(allCols, upsertKeyCols, config.WatermarkColumn,
config.FlowJobName, copyTransformation)
@@ -428,9 +432,11 @@ func NewSnowflakeAvroWriteHandler(
func (s *SnowflakeAvroWriteHandler) HandleAppendMode(
copyInfo *CopyInfo,
) error {
parsedDstTable, _ := utils.ParseSchemaTable(s.dstTableName)
//nolint:gosec
copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s",
s.dstTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ","))
snowflakeSchemaTableNormalize(parsedDstTable), copyInfo.columnsSQL,
copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ","))
s.connector.logger.Info("running copy command: " + copyCmd)
_, err := s.connector.database.ExecContext(s.connector.ctx, copyCmd)
if err != nil {
@@ -441,13 +447,12 @@ func (s *SnowflakeAvroWriteHandler) HandleAppendMode(
return nil
}

func GenerateMergeCommand(
func generateUpsertMergeCommand(
allCols []string,
upsertKeyCols []string,
watermarkCol string,
tempTableName string,
dstTable string,
) (string, error) {
) string {
// all cols are acquired from snowflake schema, so let us try to make upsert key cols match the case
// and also the watermark col, then the quoting should be fine
caseMatchedCols := map[string]string{}
@@ -495,7 +500,7 @@
`, dstTable, selectCmd, upsertKeyClause,
updateSetClause, insertColumnsClause, insertValuesClause)

return mergeCmd, nil
return mergeCmd
}
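For orientation, the generated command follows the usual Snowflake upsert shape; sketched here with the template's placeholders (the actual mergeCmd format string lives in the truncated hunk above, and the TARGET/SOURCE aliases are an assumption):

MERGE INTO <dstTable> AS TARGET
USING (<selectCmd over the temp table>) AS SOURCE
ON <upsertKeyClause, e.g. TARGET."ID" = SOURCE."ID">
WHEN MATCHED THEN UPDATE SET <updateSetClause>
WHEN NOT MATCHED THEN INSERT (<insertColumnsClause>) VALUES (<insertValuesClause>)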

// HandleUpsertMode handles the upsert mode
@@ -530,10 +535,7 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode(
}
s.connector.logger.Info("copied file from stage " + s.stage + " to temp table " + tempTableName)

mergeCmd, err := GenerateMergeCommand(allCols, upsertKeyCols, watermarkCol, tempTableName, s.dstTableName)
if err != nil {
return fmt.Errorf("failed to generate merge command: %w", err)
}
mergeCmd := generateUpsertMergeCommand(allCols, upsertKeyCols, tempTableName, s.dstTableName)

startTime := time.Now()
rows, err := s.connector.database.ExecContext(s.connector.ctx, mergeCmd)
53 changes: 30 additions & 23 deletions flow/connectors/snowflake/snowflake.go
@@ -435,7 +435,7 @@ func (c *SnowflakeConnector) SetupNormalizedTables(
}

normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(
tableIdentifier, tableSchema, req.SoftDeleteColName, req.SyncedAtColName)
normalizedSchemaTable, tableSchema, req.SoftDeleteColName, req.SyncedAtColName)
_, err = c.database.ExecContext(c.ctx, normalizedTableCreateSQL)
if err != nil {
return nil, fmt.Errorf("[sf] error while creating normalized table: %w", err)
@@ -562,8 +562,8 @@ func (c *SnowflakeConnector) syncRecordsViaAvro(
qrepConfig := &protos.QRepConfig{
StagingPath: "",
FlowJobName: req.FlowJobName,
DestinationTableIdentifier: fmt.Sprintf("%s.%s", c.metadataSchema,
rawTableIdentifier),
DestinationTableIdentifier: strings.ToLower(fmt.Sprintf("%s.%s", c.metadataSchema,
rawTableIdentifier)),
}
avroSyncer := NewSnowflakeAvroSyncMethod(qrepConfig, c)
destinationTableSchema, err := c.getTableSchema(qrepConfig.DestinationTableIdentifier)
@@ -759,50 +759,50 @@ func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableId
}

func generateCreateTableSQLForNormalizedTable(
sourceTableIdentifier string,
dstSchemaTable *utils.SchemaTable,
sourceTableSchema *protos.TableSchema,
softDeleteColName string,
syncedAtColName string,
) string {
createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns)+2)
for columnName, genericColumnType := range sourceTableSchema.Columns {
columnNameUpper := strings.ToUpper(columnName)
normalizedColName := SnowflakeIdentifierNormalize(columnName)
sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))
if err != nil {
slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType),
slog.Any("error", err))
continue
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper, sfColType))
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s,`, normalizedColName, sfColType))
}

// add a _peerdb_is_deleted column to the normalized table
// this is boolean default false, and is used to mark records as deleted
if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`"%s" BOOLEAN DEFAULT FALSE,`, softDeleteColName))
fmt.Sprintf(`%s BOOLEAN DEFAULT FALSE,`, softDeleteColName))
}

// add a _peerdb_synced column to the normalized table
// this is a timestamp column that is used to mark records as synced
// default value is the current timestamp (snowflake)
if syncedAtColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`"%s" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName))
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName))
}

// add composite primary key to the table
if len(sourceTableSchema.PrimaryKeyColumns) > 0 {
primaryKeyColsUpperQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns))
normalizedPrimaryKeyCols := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns))
for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
primaryKeyColsUpperQuoted = append(primaryKeyColsUpperQuoted,
fmt.Sprintf(`"%s"`, strings.ToUpper(primaryKeyCol)))
normalizedPrimaryKeyCols = append(normalizedPrimaryKeyCols,
SnowflakeIdentifierNormalize(primaryKeyCol))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(primaryKeyColsUpperQuoted, ","), ",")))
strings.TrimSuffix(strings.Join(normalizedPrimaryKeyCols, ","), ",")))
}

return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
return fmt.Sprintf(createNormalizedTableSQL, snowflakeSchemaTableNormalize(dstSchemaTable),
strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
}
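Putting the pieces together: for a hypothetical destination myschema."MyTable" with columns id (numeric, the primary key) and userName (text), and assuming the createNormalizedTableSQL template expands to CREATE TABLE IF NOT EXISTS %s(%s) with the default soft-delete and synced-at column names (both column names and type mappings here are assumptions, not taken from the PR), the generated DDL would look roughly like:

CREATE TABLE IF NOT EXISTS "MYSCHEMA"."MyTable"(
	"ID" NUMBER,
	"userName" STRING,
	_PEERDB_IS_DELETED BOOLEAN DEFAULT FALSE,
	_PEERDB_SYNCED TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
	PRIMARY KEY("ID")
)

Note that the soft-delete and synced-at columns are now interpolated unquoted, so their configured names are subject to Snowflake's default upper-case folding.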

@@ -821,6 +821,10 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
normalizeReq *model.NormalizeRecordsRequest,
) (int64, error) {
normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
parsedDstTable, err := utils.ParseSchemaTable(destinationTableIdentifier)
if err != nil {
return 0, fmt.Errorf("unable to parse destination table '%s'", destinationTableIdentifier)
}
columnNames := maps.Keys(normalizedTableSchema.Columns)

flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
@@ -832,7 +836,7 @@
genericColumnType, err)
}

targetColumnName := fmt.Sprintf(`"%s"`, strings.ToUpper(columnName))
targetColumnName := SnowflakeIdentifierNormalize(columnName)
switch qvalue.QValueKind(genericColumnType) {
case qvalue.QValueKindBytes, qvalue.QValueKindBit:
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+
@@ -865,7 +869,7 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(

quotedUpperColNames := make([]string, 0, len(columnNames))
for _, columnName := range columnNames {
quotedUpperColNames = append(quotedUpperColNames, fmt.Sprintf(`"%s"`, strings.ToUpper(columnName)))
quotedUpperColNames = append(quotedUpperColNames, SnowflakeIdentifierNormalize(columnName))
}
// append synced_at column
quotedUpperColNames = append(quotedUpperColNames,
@@ -876,8 +880,8 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(

insertValuesSQLArray := make([]string, 0, len(columnNames))
for _, columnName := range columnNames {
quotedUpperColumnName := fmt.Sprintf(`"%s"`, strings.ToUpper(columnName))
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("SOURCE.%s", quotedUpperColumnName))
normalizedColName := SnowflakeIdentifierNormalize(columnName)
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("SOURCE.%s", normalizedColName))
}
// fill in synced_at column
insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP")
@@ -899,10 +903,13 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
}
updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")

normalizedpkeyColsArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
pkeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
for _, pkeyColName := range normalizedTableSchema.PrimaryKeyColumns {
normalizedPkeyColName := SnowflakeIdentifierNormalize(pkeyColName)
normalizedpkeyColsArray = append(normalizedpkeyColsArray, normalizedPkeyColName)
pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s",
pkeyColName, pkeyColName))
normalizedPkeyColName, normalizedPkeyColName))
}
// TARGET.<pkey1> = SOURCE.<pkey1> AND TARGET.<pkey2> = SOURCE.<pkey2> ...
pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ")
@@ -916,9 +923,9 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
}
}

mergeStatement := fmt.Sprintf(mergeStatementSQL, destinationTableIdentifier, toVariantColumnName,
rawTableIdentifier, normalizeBatchID, syncBatchID, flattenedCastsSQL,
fmt.Sprintf("(%s)", strings.Join(normalizedTableSchema.PrimaryKeyColumns, ",")),
mergeStatement := fmt.Sprintf(mergeStatementSQL, snowflakeSchemaTableNormalize(parsedDstTable),
toVariantColumnName, rawTableIdentifier, normalizeBatchID, syncBatchID, flattenedCastsSQL,
fmt.Sprintf("(%s)", strings.Join(normalizedpkeyColsArray, ",")),
pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)

startTime := time.Now()
@@ -1045,8 +1052,8 @@ func (c *SnowflakeConnector) generateUpdateStatements(
otherCols := utils.ArrayMinus(allCols, unchangedColsArray)
tmpArray := make([]string, 0, len(otherCols)+2)
for _, colName := range otherCols {
quotedUpperColName := fmt.Sprintf(`"%s"`, strings.ToUpper(colName))
tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", quotedUpperColName, quotedUpperColName))
normalizedColName := SnowflakeIdentifierNormalize(colName)
tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", normalizedColName, normalizedColName))
}

// set the synced at column to the current timestamp
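With normalization applied, the SET clause entries that this loop renders for hypothetical columns id and userName (neither in the unchanged-TOAST set) look like:

"ID" = SOURCE."ID"
"userName" = SOURCE."userName"

i.e. lowercase source columns are folded to upper case while mixed-case names keep their quoting verbatim.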
21 changes: 21 additions & 0 deletions flow/connectors/utils/identifiers.go
@@ -3,6 +3,7 @@ package utils
import (
"fmt"
"strings"
"unicode"
)

func QuoteIdentifier(identifier string) string {
@@ -28,3 +29,23 @@ func ParseSchemaTable(tableName string) (*SchemaTable, error) {

return &SchemaTable{schema, table}, nil
}

// I think these only work with ASCII?
func IsUpper(s string) bool {
for _, r := range s {
if !unicode.IsUpper(r) && unicode.IsLetter(r) {
return false
}
}
return true
}

// I think these only work with ASCII?
func IsLower(s string) bool {
for _, r := range s {
if !unicode.IsLower(r) && unicode.IsLetter(r) {
return false
}
}
return true
}
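On the "ASCII" question in the comments above: unicode.IsUpper and unicode.IsLower handle any cased Unicode letter, not just ASCII; the real subtlety is that non-letters are skipped entirely, so both helpers return true for strings containing no letters at all. A few illustrative calls (hypothetical, not tests from this PR), assuming the package is imported as utils:

utils.IsUpper("FOO_BAR") // true – underscore is not a letter, so it is ignored
utils.IsUpper("FOO1")    // true – digits are ignored too
utils.IsLower("тест")    // true – non-ASCII lowercase letters are handled
utils.IsUpper("")        // true – vacuously true: no letters at all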