Skip to content

Commit

Permalink
refactored SF to standalone merge generator
Browse files Browse the repository at this point in the history
  • Loading branch information
heavycrystal committed Jan 8, 2024
1 parent 311a65d commit 40e56d4
Show file tree
Hide file tree
Showing 4 changed files with 317 additions and 227 deletions.
208 changes: 208 additions & 0 deletions flow/connectors/snowflake/merge_stmt_generator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package connsnowflake

import (
"fmt"
"strings"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model/qvalue"
)

// mergeStmtGenerator holds everything needed to build the Snowflake MERGE
// statement that normalizes raw CDC records into a destination table.
type mergeStmtGenerator struct {
	// name of the raw table the records are merged from
	rawTableName string
	// destination table name, used to retrieve records from raw table
	dstTableName string
	// last synced batchID.
	syncBatchID int64
	// last normalized batchID.
	normalizeBatchID int64
	// the schema of the table to merge into
	normalizedTableSchema *protos.TableSchema
	// array of toast column combinations that are unchanged
	unchangedToastColumns []string
	// _PEERDB_IS_DELETED and _SYNCED_AT columns
	peerdbCols *protos.PeerDBColumns
}

// generateMergeStmt builds the MERGE statement that folds raw CDC records
// (stored as a VARIANT column in the raw table) into the destination table:
// each column is cast from the VARIANT back to its Snowflake type, then the
// statement inserts new rows, applies per-TOAST-group updates, and deletes
// (or soft-deletes) rows whose record type marks a delete.
// Returns an error if the destination table name cannot be parsed or a column
// type has no Snowflake equivalent.
func (c *mergeStmtGenerator) generateMergeStmt() (string, error) {
	parsedDstTable, err := utils.ParseSchemaTable(c.dstTableName)
	if err != nil {
		return "", fmt.Errorf("failed to parse destination table %s: %w", c.dstTableName, err)
	}
	columnNames := utils.TableSchemaColumnNames(c.normalizedTableSchema)

	flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(c.normalizedTableSchema))
	err = utils.IterColumnsError(c.normalizedTableSchema, func(columnName, genericColumnType string) error {
		qvKind := qvalue.QValueKind(genericColumnType)
		sfType, err := qValueKindToSnowflakeType(qvKind)
		if err != nil {
			return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err)
		}

		targetColumnName := SnowflakeIdentifierNormalize(columnName)
		switch qvKind {
		case qvalue.QValueKindBytes, qvalue.QValueKindBit:
			// binary values travel base64-encoded inside the VARIANT
			flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+
				"AS %s,", toVariantColumnName, columnName, targetColumnName))
		case qvalue.QValueKindGeography:
			flattenedCastsSQLArray = append(flattenedCastsSQLArray,
				fmt.Sprintf("TO_GEOGRAPHY(CAST(%s:\"%s\" AS STRING),true) AS %s,",
					toVariantColumnName, columnName, targetColumnName))
		case qvalue.QValueKindGeometry:
			flattenedCastsSQLArray = append(flattenedCastsSQLArray,
				fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s,",
					toVariantColumnName, columnName, targetColumnName))
		case qvalue.QValueKindJSON:
			flattenedCastsSQLArray = append(flattenedCastsSQLArray,
				fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s,",
					toVariantColumnName, columnName, targetColumnName))
		// TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types
		// case model.ColumnTypeTime:
		// 	flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+
		// 		"Microseconds*1000) "+
		// 		"AS %s,", toVariantColumnName, columnName, columnName))
		default:
			if qvKind == qvalue.QValueKindNumeric {
				// TRY_CAST via text avoids errors on out-of-range numeric values
				flattenedCastsSQLArray = append(flattenedCastsSQLArray,
					fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s,",
						toVariantColumnName, columnName, sfType, targetColumnName))
			} else {
				flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s,",
					toVariantColumnName, columnName, sfType, targetColumnName))
			}
		}
		return nil
	})
	if err != nil {
		return "", err
	}
	flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ""), ",")

	quotedUpperColNames := make([]string, 0, len(columnNames))
	for _, columnName := range columnNames {
		quotedUpperColNames = append(quotedUpperColNames, SnowflakeIdentifierNormalize(columnName))
	}
	// append synced_at column
	quotedUpperColNames = append(quotedUpperColNames,
		fmt.Sprintf(`"%s"`, strings.ToUpper(c.peerdbCols.SyncedAtColName)),
	)

	insertColumnsSQL := strings.TrimSuffix(strings.Join(quotedUpperColNames, ","), ",")

	insertValuesSQLArray := make([]string, 0, len(columnNames))
	for _, columnName := range columnNames {
		normalizedColName := SnowflakeIdentifierNormalize(columnName)
		insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("SOURCE.%s", normalizedColName))
	}
	// fill in synced_at column
	insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP")
	insertValuesSQL := strings.Join(insertValuesSQLArray, ",")
	updateStatementsforToastCols := c.generateUpdateStatements(columnNames)

	// handling the case when an insert and delete happen in the same batch, with updates in the middle
	// with soft-delete, we want the row to be in the destination with SOFT_DELETE true
	// the current merge statement doesn't do that, so we add another case to insert the DeleteRecord
	if c.peerdbCols.SoftDelete && (c.peerdbCols.SoftDeleteColName != "") {
		// NOTE(review): SoftDeleteColName is appended unquoted here, unlike the
		// other column names — relies on Snowflake uppercasing it; confirm intended.
		softDeleteInsertColumnsSQL := strings.Join(append(quotedUpperColNames,
			c.peerdbCols.SoftDeleteColName), ",")
		softDeleteInsertValuesSQL := insertValuesSQL + ",TRUE"
		updateStatementsforToastCols = append(updateStatementsforToastCols,
			fmt.Sprintf("WHEN NOT MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) THEN INSERT (%s) VALUES(%s)",
				softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL))
	}
	updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")

	normalizedpkeyColsArray := make([]string, 0, len(c.normalizedTableSchema.PrimaryKeyColumns))
	pkeySelectSQLArray := make([]string, 0, len(c.normalizedTableSchema.PrimaryKeyColumns))
	for _, pkeyColName := range c.normalizedTableSchema.PrimaryKeyColumns {
		normalizedPkeyColName := SnowflakeIdentifierNormalize(pkeyColName)
		normalizedpkeyColsArray = append(normalizedpkeyColsArray, normalizedPkeyColName)
		pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s",
			normalizedPkeyColName, normalizedPkeyColName))
	}
	// TARGET.<pkey1> = SOURCE.<pkey1> AND TARGET.<pkey2> = SOURCE.<pkey2> ...
	pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ")

	deletePart := "DELETE"
	if c.peerdbCols.SoftDelete {
		colName := c.peerdbCols.SoftDeleteColName
		deletePart = fmt.Sprintf("UPDATE SET %s = TRUE", colName)
		if c.peerdbCols.SyncedAtColName != "" {
			deletePart = fmt.Sprintf("%s, %s = CURRENT_TIMESTAMP", deletePart, c.peerdbCols.SyncedAtColName)
		}
	}

	mergeStatement := fmt.Sprintf(mergeStatementSQL, snowflakeSchemaTableNormalize(parsedDstTable),
		toVariantColumnName, c.rawTableName, c.normalizeBatchID, c.syncBatchID, flattenedCastsSQL,
		fmt.Sprintf("(%s)", strings.Join(normalizedpkeyColsArray, ",")),
		pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)

	return mergeStatement, nil
}

/*
generateUpdateStatements builds the WHEN MATCHED ... THEN UPDATE clauses of the
merge statement, one (or two, with soft-delete) per unique group of unchanged
toast columns.

Inputs:
 1. allCols: all column names of the destination table.

Generator state used:
  - c.unchangedToastColumns: unique comma-separated groups of unchanged toast columns.
  - c.peerdbCols.SyncedAtColName: when non-empty, set to CURRENT_TIMESTAMP on every update.
  - c.peerdbCols.SoftDelete / SoftDeleteColName: when enabled, the soft-delete column is
    set to FALSE on a normal update (covers an insert arriving after a soft delete).

Algorithm:
 1. Iterate over each unique group of unchanged toast columns.
 2. Split the group into individual column names.
 3. The columns to update are the set difference between allCols and the unchanged columns.
 4. Emit an update clause matching non-delete records (_PEERDB_RECORD_TYPE != 2) whose
    _PEERDB_UNCHANGED_TOAST_COLUMNS equal the group, setting the other columns from
    SOURCE. The unchanged toast columns are left untouched (not nulled).
 5. With soft-delete enabled, also emit a clause for delete records
    (_PEERDB_RECORD_TYPE = 2) that applies the same updates but sets the soft-delete
    column to TRUE: the backfill already happened on the pull side, so a DeleteRecord
    in the same batch as updates is treated as an update plus soft-delete.
 6. Return the list of generated update clauses.
*/
func (c *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string {
	updateStmts := make([]string, 0, len(c.unchangedToastColumns))

	for _, cols := range c.unchangedToastColumns {
		unchangedColsArray := strings.Split(cols, ",")
		otherCols := utils.ArrayMinus(allCols, unchangedColsArray)
		tmpArray := make([]string, 0, len(otherCols)+2)
		for _, colName := range otherCols {
			normalizedColName := SnowflakeIdentifierNormalize(colName)
			tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", normalizedColName, normalizedColName))
		}

		// set the synced at column to the current timestamp
		if c.peerdbCols.SyncedAtColName != "" {
			tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`,
				c.peerdbCols.SyncedAtColName))
		}
		// set soft-deleted to false, tackles insert after soft-delete
		if c.peerdbCols.SoftDelete && (c.peerdbCols.SoftDeleteColName != "") {
			tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = FALSE`,
				c.peerdbCols.SoftDeleteColName))
		}

		ssep := strings.Join(tmpArray, ", ")
		updateStmt := fmt.Sprintf(`WHEN MATCHED AND
		(SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s'
		THEN UPDATE SET %s `, cols, ssep)
		updateStmts = append(updateStmts, updateStmt)

		// generates update statements for the case where updates and deletes happen in the same branch
		// the backfill has happened from the pull side already, so treat the DeleteRecord as an update
		// and then set soft-delete to true. The last element of tmpArray is the
		// `"<col>" = FALSE` entry appended above, so it is replaced with `= TRUE` here.
		if c.peerdbCols.SoftDelete && (c.peerdbCols.SoftDeleteColName != "") {
			tmpArray = append(tmpArray[:len(tmpArray)-1], fmt.Sprintf(`"%s" = TRUE`,
				c.peerdbCols.SoftDeleteColName))
			ssep := strings.Join(tmpArray, ", ")
			updateStmt := fmt.Sprintf(`WHEN MATCHED AND
			(SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s'
			THEN UPDATE SET %s `, cols, ssep)
			updateStmts = append(updateStmts, updateStmt)
		}
	}
	return updateStmts
}
88 changes: 77 additions & 11 deletions flow/connectors/snowflake/merge_stmt_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,66 @@ import (
"reflect"
"strings"
"testing"

"github.com/PeerDB-io/peer-flow/generated/protos"
)

// With a single empty unchanged-TOAST group and soft delete disabled, every
// column plus the synced-at column must be updated.
func TestGenerateUpdateStatement_EmptyColumns(t *testing.T) {
	gen := &mergeStmtGenerator{
		unchangedToastColumns: []string{""},
		peerdbCols: &protos.PeerDBColumns{
			SoftDelete:        false,
			SyncedAtColName:   "_PEERDB_SYNCED_AT",
			SoftDeleteColName: "_PEERDB_SOFT_DELETE",
		},
	}
	want := []string{
		`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS=''
		THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3",
		"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`,
	}

	got := gen.generateUpdateStatements([]string{"col1", "col2", "col3"})

	// whitespace is not significant in the generated SQL, so strip it before comparing
	for i := range want {
		want[i] = removeSpacesTabsNewlines(want[i])
		got[i] = removeSpacesTabsNewlines(got[i])
	}

	if !reflect.DeepEqual(got, want) {
		t.Errorf("Unexpected result. Expected: %v, but got: %v", want, got)
	}
}

func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) {
c := &SnowflakeConnector{}
allCols := []string{"col1", "col2", "col3"}
unchangedToastCols := []string{"", "col2,col3", "col2", "col3"}

expected := []string{
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS=''
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`,
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3'
THEN UPDATE SET "COL1" = SOURCE."COL1",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`,
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2'
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL3" = SOURCE."COL3",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`,
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col3'
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`,
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`,
}
mergeGen := &mergeStmtGenerator{
unchangedToastColumns: unchangedToastCols,
peerdbCols: &protos.PeerDBColumns{
SoftDelete: false,
SyncedAtColName: "_PEERDB_SYNCED_AT",
SoftDeleteColName: "_PEERDB_SOFT_DELETE",
},
}
result := c.generateUpdateStatements("_PEERDB_SYNCED_AT", "_PEERDB_SOFT_DELETE", false, allCols, unchangedToastCols)
result := mergeGen.generateUpdateStatements(allCols)

for i := range expected {
expected[i] = removeSpacesTabsNewlines(expected[i])
Expand All @@ -37,17 +75,45 @@ func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) {
}
}

func TestGenerateUpdateStatement_EmptyColumns(t *testing.T) {
c := &SnowflakeConnector{}
func TestGenerateUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing.T) {
allCols := []string{"col1", "col2", "col3"}
unchangedToastCols := []string{""}
unchangedToastCols := []string{"", "col2,col3", "col2", "col3"}

expected := []string{
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS=''
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3",
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS=''
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3'
THEN UPDATE SET "COL1" = SOURCE."COL1",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3'
THEN UPDATE SET "COL1" = SOURCE."COL1",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2'
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL3" = SOURCE."COL3",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2'
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL3" = SOURCE."COL3",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col3'
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`,
`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col3'
THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2",
"_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`,
}
mergeGen := &mergeStmtGenerator{
unchangedToastColumns: unchangedToastCols,
peerdbCols: &protos.PeerDBColumns{
SoftDelete: true,
SyncedAtColName: "_PEERDB_SYNCED_AT",
SoftDeleteColName: "_PEERDB_SOFT_DELETE",
},
}
result := c.generateUpdateStatements("_PEERDB_SYNCED_AT", "_PEERDB_SOFT_DELETE", false, allCols, unchangedToastCols)
result := mergeGen.generateUpdateStatements(allCols)

for i := range expected {
expected[i] = removeSpacesTabsNewlines(expected[i])
Expand Down
Loading

0 comments on commit 40e56d4

Please sign in to comment.