diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 4f5a8ca8d9..8a8181ba61 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -209,6 +209,8 @@ services: DATABASE_URL: postgres://postgres:postgres@catalog:5432/postgres PEERDB_FLOW_SERVER_HTTP: http://flow_api:8113 PEERDB_PASSWORD: + NEXTAUTH_SECRET: __changeme__ + NEXTAUTH_URL: http://localhost:3000 depends_on: - flow-api diff --git a/docker-compose.yml b/docker-compose.yml index d6a4fbb127..7ac7e19bb5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -174,6 +174,8 @@ services: <<: *catalog-config DATABASE_URL: postgres://postgres:postgres@catalog:5432/postgres PEERDB_FLOW_SERVER_HTTP: http://flow_api:8113 + NEXTAUTH_SECRET: __changeme__ + NEXTAUTH_URL: http://localhost:3000 depends_on: - flow-api diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 25693f5ee1..df8e288215 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -54,6 +54,7 @@ func (a *FlowableActivity) CheckConnection( ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName) dstConn, err := connectors.GetCDCSyncConnector(ctx, config.Peer) if err != nil { + a.Alerter.LogFlowError(ctx, config.FlowName, err) return nil, fmt.Errorf("failed to get connector: %w", err) } defer connectors.CloseConnector(dstConn) diff --git a/flow/activities/snapshot_activity.go b/flow/activities/snapshot_activity.go index 3abd39c588..4fa945b896 100644 --- a/flow/activities/snapshot_activity.go +++ b/flow/activities/snapshot_activity.go @@ -52,12 +52,22 @@ func (a *SnapshotActivity) SetupReplication( replicationErr := make(chan error) defer close(replicationErr) + closeConnectionForError := func(err error) { + slog.ErrorContext(ctx, "failed to setup replication", slog.Any("error", err)) + a.Alerter.LogFlowError(ctx, config.FlowJobName, err) + // it is important to close the connection here as it is not closed in CloseSlotKeepAlive + connCloseErr := conn.Close() + if connCloseErr != nil { + slog.ErrorContext(ctx, "failed to close connection", slog.Any("error", connCloseErr)) + } + } + // This now happens in a goroutine go func() { pgConn := conn.(*connpostgres.PostgresConnector) err = pgConn.SetupReplication(slotSignal, config) if err != nil { - slog.ErrorContext(ctx, "failed to setup replication", slog.Any("error", err)) + closeConnectionForError(err) replicationErr <- err return } @@ -69,12 +79,12 @@ func (a *SnapshotActivity) SetupReplication( case slotInfo = <-slotSignal.SlotCreated: slog.InfoContext(ctx, fmt.Sprintf("slot '%s' created", slotInfo.SlotName)) case err := <-replicationErr: - a.Alerter.LogFlowError(ctx, config.FlowJobName, err) + closeConnectionForError(err) return nil, fmt.Errorf("failed to setup replication: %w", err) } if slotInfo.Err != nil { - a.Alerter.LogFlowError(ctx, config.FlowJobName, slotInfo.Err) + closeConnectionForError(slotInfo.Err) return nil, fmt.Errorf("slot error: %w", slotInfo.Err) } diff --git a/flow/cmd/mirror_status.go b/flow/cmd/mirror_status.go index 45c2a134e3..c5743ad575 100644 --- a/flow/cmd/mirror_status.go +++ b/flow/cmd/mirror_status.go @@ -17,8 +17,10 @@ func (h *FlowRequestHandler) MirrorStatus( ctx context.Context, req *protos.MirrorStatusRequest, ) (*protos.MirrorStatusResponse, error) { + slog.Info("Mirror status endpoint called", slog.String(string(shared.FlowNameKey), req.FlowJobName)) cdcFlow, err := h.isCDCFlow(ctx, req.FlowJobName) if err != nil { + slog.Error(fmt.Sprintf("unable to query flow: %s", err.Error())) return 
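Note on the compose changes at the top of this diff: `NEXTAUTH_SECRET: __changeme__` is only a placeholder and should be replaced before deploying. A common choice is a random 32-byte, base64-encoded value; a minimal sketch of generating one in Go (any equivalent source of randomness works):

```go
package main

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
)

func main() {
	// 32 random bytes, base64-encoded: a typical shape for a NEXTAUTH_SECRET value.
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		panic(err)
	}
	fmt.Println(base64.StdEncoding.EncodeToString(buf))
}
```

The printed string then replaces `__changeme__` in both compose files.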
&protos.MirrorStatusResponse{ ErrorMessage: fmt.Sprintf("unable to query flow: %s", err.Error()), }, nil @@ -73,6 +75,7 @@ func (h *FlowRequestHandler) CDCFlowStatus( ctx context.Context, req *protos.MirrorStatusRequest, ) (*protos.CDCMirrorStatus, error) { + slog.Info("CDC mirror status endpoint called", slog.String(string(shared.FlowNameKey), req.FlowJobName)) config, err := h.getFlowConfigFromCatalog(req.FlowJobName) if err != nil { return nil, err @@ -80,21 +83,11 @@ func (h *FlowRequestHandler) CDCFlowStatus( var initialCopyStatus *protos.SnapshotStatus - cloneJobNames, err := h.getCloneTableFlowNames(ctx, req.FlowJobName) + cloneStatuses, err := h.cloneTableSummary(ctx, req.FlowJobName) if err != nil { return nil, err } - cloneStatuses := []*protos.CloneTableSummary{} - for _, cloneJobName := range cloneJobNames { - cloneStatus, err := h.cloneTableSummary(ctx, cloneJobName) - if err != nil { - return nil, err - } - - cloneStatuses = append(cloneStatuses, cloneStatus) - } - initialCopyStatus = &protos.SnapshotStatus{ Clones: cloneStatuses, } @@ -108,70 +101,103 @@ func (h *FlowRequestHandler) CDCFlowStatus( func (h *FlowRequestHandler) cloneTableSummary( ctx context.Context, flowJobName string, -) (*protos.CloneTableSummary, error) { - cfg := h.getQRepConfigFromCatalog(flowJobName) - res := &protos.CloneTableSummary{ - FlowJobName: flowJobName, - TableName: cfg.DestinationTableIdentifier, - } - +) ([]*protos.CloneTableSummary, error) { q := ` SELECT - MIN(start_time) AS StartTime, + qp.flow_name, + qr.config_proto, + MIN(qp.start_time) AS StartTime, COUNT(*) AS NumPartitionsTotal, - COUNT(CASE WHEN end_time IS NOT NULL THEN 1 END) AS NumPartitionsCompleted, - SUM(rows_in_partition) FILTER (WHERE end_time IS NOT NULL) AS NumRowsSynced, - AVG(EXTRACT(EPOCH FROM (end_time - start_time)) * 1000) FILTER (WHERE end_time IS NOT NULL) AS AvgTimePerPartitionMs + COUNT(CASE WHEN qp.end_time IS NOT NULL THEN 1 END) AS NumPartitionsCompleted, + SUM(qp.rows_in_partition) FILTER (WHERE qp.end_time IS NOT NULL) AS NumRowsSynced, + AVG(EXTRACT(EPOCH FROM (qp.end_time - qp.start_time)) * 1000) FILTER (WHERE qp.end_time IS NOT NULL) AS AvgTimePerPartitionMs FROM - peerdb_stats.qrep_partitions + peerdb_stats.qrep_partitions qp + JOIN + peerdb_stats.qrep_runs qr + ON + qp.flow_name = qr.flow_name WHERE - flow_name = $1; + qp.flow_name ILIKE $1 + GROUP BY + qp.flow_name, qr.config_proto; ` + var flowName pgtype.Text + var configBytes []byte var startTime pgtype.Timestamp var numPartitionsTotal pgtype.Int8 var numPartitionsCompleted pgtype.Int8 var numRowsSynced pgtype.Int8 var avgTimePerPartitionMs pgtype.Float8 - err := h.pool.QueryRow(ctx, q, flowJobName).Scan( - &startTime, - &numPartitionsTotal, - &numPartitionsCompleted, - &numRowsSynced, - &avgTimePerPartitionMs, - ) + rows, err := h.pool.Query(ctx, q, "clone_"+flowJobName+"_%") if err != nil { - return nil, fmt.Errorf("unable to query qrep partition - %s: %w", flowJobName, err) + slog.Error(fmt.Sprintf("unable to query initial load partition - %s: %s", flowJobName, err.Error())) + return nil, fmt.Errorf("unable to query initial load partition - %s: %w", flowJobName, err) } - if startTime.Valid { - res.StartTime = timestamppb.New(startTime.Time) - } + defer rows.Close() - if numPartitionsTotal.Valid { - res.NumPartitionsTotal = int32(numPartitionsTotal.Int64) - } + cloneStatuses := []*protos.CloneTableSummary{} + for rows.Next() { + if err := rows.Scan( + &flowName, + &configBytes, + &startTime, + &numPartitionsTotal, + &numPartitionsCompleted, + 
&numRowsSynced, + &avgTimePerPartitionMs, + ); err != nil { + return nil, fmt.Errorf("unable to scan initial load partition - %s: %w", flowJobName, err) + } - if numPartitionsCompleted.Valid { - res.NumPartitionsCompleted = int32(numPartitionsCompleted.Int64) - } + var res protos.CloneTableSummary - if numRowsSynced.Valid { - res.NumRowsSynced = numRowsSynced.Int64 - } + if flowName.Valid { + res.FlowJobName = flowName.String + } + if startTime.Valid { + res.StartTime = timestamppb.New(startTime.Time) + } - if avgTimePerPartitionMs.Valid { - res.AvgTimePerPartitionMs = int64(avgTimePerPartitionMs.Float64) - } + if numPartitionsTotal.Valid { + res.NumPartitionsTotal = int32(numPartitionsTotal.Int64) + } - return res, nil + if numPartitionsCompleted.Valid { + res.NumPartitionsCompleted = int32(numPartitionsCompleted.Int64) + } + + if numRowsSynced.Valid { + res.NumRowsSynced = numRowsSynced.Int64 + } + + if avgTimePerPartitionMs.Valid { + res.AvgTimePerPartitionMs = int64(avgTimePerPartitionMs.Float64) + } + + if configBytes != nil { + var config protos.QRepConfig + if err := proto.Unmarshal(configBytes, &config); err != nil { + slog.Error(fmt.Sprintf("unable to unmarshal config: %s", err.Error())) + return nil, fmt.Errorf("unable to unmarshal config: %w", err) + } + res.TableName = config.DestinationTableIdentifier + } + + cloneStatuses = append(cloneStatuses, &res) + + } + return cloneStatuses, nil } func (h *FlowRequestHandler) QRepFlowStatus( ctx context.Context, req *protos.MirrorStatusRequest, ) (*protos.QRepMirrorStatus, error) { + slog.Info("QRep Flow status endpoint called", slog.String(string(shared.FlowNameKey), req.FlowJobName)) partitionStatuses, err := h.getPartitionStatuses(ctx, req.FlowJobName) if err != nil { slog.Error(fmt.Sprintf("unable to query qrep partition - %s: %s", req.FlowJobName, err.Error())) @@ -240,11 +266,13 @@ func (h *FlowRequestHandler) getFlowConfigFromCatalog( err = h.pool.QueryRow(context.Background(), "SELECT config_proto FROM flows WHERE name = $1", flowJobName).Scan(&configBytes) if err != nil { + slog.Error(fmt.Sprintf("unable to query flow config from catalog: %s", err.Error())) return nil, fmt.Errorf("unable to query flow config from catalog: %w", err) } err = proto.Unmarshal(configBytes, &config) if err != nil { + slog.Error(fmt.Sprintf("unable to unmarshal flow config: %s", err.Error())) return nil, fmt.Errorf("unable to unmarshal flow config: %w", err) } @@ -296,6 +324,7 @@ func (h *FlowRequestHandler) isCDCFlow(ctx context.Context, flowJobName string) var query pgtype.Text err := h.pool.QueryRow(ctx, "SELECT query_string FROM flows WHERE name = $1", flowJobName).Scan(&query) if err != nil { + slog.Error(fmt.Sprintf("unable to query flow: %s", err.Error())) return false, fmt.Errorf("unable to query flow: %w", err) } @@ -306,36 +335,16 @@ func (h *FlowRequestHandler) isCDCFlow(ctx context.Context, flowJobName string) return false, nil } -func (h *FlowRequestHandler) getCloneTableFlowNames(ctx context.Context, flowJobName string) ([]string, error) { - q := "SELECT flow_name FROM peerdb_stats.qrep_runs WHERE flow_name ILIKE $1" - rows, err := h.pool.Query(ctx, q, "clone_"+flowJobName+"_%") - if err != nil { - return nil, fmt.Errorf("unable to getCloneTableFlowNames: %w", err) - } - defer rows.Close() - - flowNames := []string{} - for rows.Next() { - var name pgtype.Text - if err := rows.Scan(&name); err != nil { - return nil, fmt.Errorf("unable to scan flow row: %w", err) - } - if name.Valid { - flowNames = append(flowNames, name.String) - } - } - 
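The rewritten `cloneTableSummary` above scans nullable aggregate columns through pgtype wrappers and only copies values whose `Valid` flag is set. A self-contained sketch of that pgx pattern, using a hypothetical stats table rather than PeerDB's catalog schema:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"
)

func main() {
	ctx := context.Background()
	pool, err := pgxpool.New(ctx, "postgres://postgres:postgres@localhost:5432/postgres")
	if err != nil {
		log.Fatal(err)
	}
	defer pool.Close()

	// Aggregates over a hypothetical stats table. NULL results (e.g. MIN over
	// zero rows) arrive with Valid=false instead of failing the Scan.
	var startTime pgtype.Timestamp
	var rowsSynced pgtype.Int8
	err = pool.QueryRow(ctx,
		`SELECT MIN(start_time), SUM(rows_in_partition) FROM my_stats WHERE flow_name ILIKE $1`,
		"clone_myflow_%").Scan(&startTime, &rowsSynced)
	if err != nil {
		log.Fatal(err)
	}
	if startTime.Valid {
		fmt.Println("first partition started at:", startTime.Time)
	}
	if rowsSynced.Valid {
		fmt.Println("rows synced so far:", rowsSynced.Int64)
	}
}
```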
- return flowNames, nil -} - func (h *FlowRequestHandler) getWorkflowStatus(ctx context.Context, workflowID string) (*protos.FlowStatus, error) { res, err := h.temporalClient.QueryWorkflow(ctx, workflowID, "", shared.FlowStatusQuery) if err != nil { + slog.Error(fmt.Sprintf("failed to get state in workflow with ID %s: %s", workflowID, err.Error())) return nil, fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err) } var state *protos.FlowStatus err = res.Get(&state) if err != nil { + slog.Error(fmt.Sprintf("failed to get state in workflow with ID %s: %s", workflowID, err.Error())) return nil, fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err) } return state, nil @@ -348,6 +357,7 @@ func (h *FlowRequestHandler) updateWorkflowStatus( ) error { _, err := h.temporalClient.UpdateWorkflow(ctx, workflowID, "", shared.FlowStatusUpdate, state) if err != nil { + slog.Error(fmt.Sprintf("failed to update state in workflow with ID %s: %s", workflowID, err.Error())) return fmt.Errorf("failed to update state in workflow with ID %s: %w", workflowID, err) } return nil diff --git a/flow/cmd/peer_data.go b/flow/cmd/peer_data.go index fab3fe5a5e..dd1f06bbde 100644 --- a/flow/cmd/peer_data.go +++ b/flow/cmd/peer_data.go @@ -84,23 +84,45 @@ func (h *FlowRequestHandler) GetTablesInSchema( } defer peerPool.Close() - rows, err := peerPool.Query(ctx, "SELECT table_name "+ - "FROM information_schema.tables "+ - "WHERE table_schema = $1 AND table_type = 'BASE TABLE';", req.SchemaName) + rows, err := peerPool.Query(ctx, `SELECT DISTINCT ON (t.relname) + t.relname, + CASE + WHEN c.constraint_type = 'PRIMARY KEY' OR t.relreplident = 'i' OR t.relreplident = 'f' THEN true + ELSE false + END AS can_mirror + FROM + information_schema.table_constraints c + RIGHT JOIN + pg_class t ON c.table_name = t.relname + WHERE + t.relnamespace::regnamespace::text = $1 + AND + t.relkind = 'r' + ORDER BY + t.relname, + can_mirror DESC;`, req.SchemaName) if err != nil { return &protos.SchemaTablesResponse{Tables: nil}, err } defer rows.Close() - var tables []string + var tables []*protos.TableResponse for rows.Next() { var table pgtype.Text - err := rows.Scan(&table) + var hasPkeyOrReplica pgtype.Bool + err := rows.Scan(&table, &hasPkeyOrReplica) if err != nil { return &protos.SchemaTablesResponse{Tables: nil}, err } + canMirror := false + if hasPkeyOrReplica.Valid && hasPkeyOrReplica.Bool { + canMirror = true + } - tables = append(tables, table.String) + tables = append(tables, &protos.TableResponse{ + TableName: table.String, + CanMirror: canMirror, + }) } return &protos.SchemaTablesResponse{Tables: tables}, nil } diff --git a/flow/connectors/bigquery/avro_transform_test.go b/flow/connectors/bigquery/avro_transform_test.go new file mode 100644 index 0000000000..0a9332fc87 --- /dev/null +++ b/flow/connectors/bigquery/avro_transform_test.go @@ -0,0 +1,44 @@ +package connbigquery + +import ( + "reflect" + "testing" + + "cloud.google.com/go/bigquery" +) + +func TestAvroTransform(t *testing.T) { + dstSchema := &bigquery.Schema{ + &bigquery.FieldSchema{ + Name: "col1", + Type: bigquery.GeographyFieldType, + }, + &bigquery.FieldSchema{ + Name: "col2", + Type: bigquery.JSONFieldType, + }, + &bigquery.FieldSchema{ + Name: "col3", + Type: bigquery.DateFieldType, + }, + &bigquery.FieldSchema{ + Name: "camelCol4", + Type: bigquery.StringFieldType, + }, + &bigquery.FieldSchema{ + Name: "sync_col", + Type: bigquery.TimestampFieldType, + }, + } + + expectedTransformCols := []string{ + 
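For context on the `GetTablesInSchema` query above: `pg_class.relreplident` is `'d'` (default), `'n'` (nothing), `'f'` (full) or `'i'` (index), so the CASE expression flags a table as mirrorable when it has a primary key or a usable replica identity. The same rule expressed as a tiny Go helper (a sketch, not code from this PR):

```go
package main

import "fmt"

// canMirror mirrors the CASE expression in GetTablesInSchema: a table can be
// CDC-mirrored when it has a primary key, or when its replica identity is
// INDEX ('i') or FULL ('f').
func canMirror(hasPrimaryKey bool, relreplident rune) bool {
	return hasPrimaryKey || relreplident == 'i' || relreplident == 'f'
}

func main() {
	fmt.Println(canMirror(false, 'd')) // false: no primary key, default replica identity
	fmt.Println(canMirror(false, 'f')) // true: REPLICA IDENTITY FULL
	fmt.Println(canMirror(true, 'd'))  // true: primary key present
}
```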
"ST_GEOGFROMTEXT(`col1`) AS `col1`", + "PARSE_JSON(`col2`,wide_number_mode=>'round') AS `col2`", + "CAST(`col3` AS DATE) AS `col3`", + "`camelCol4`", + } + transformedCols := getTransformedColumns(dstSchema, "sync_col", "del_col") + if !reflect.DeepEqual(transformedCols, expectedTransformCols) { + t.Errorf("Transform SQL is not correct. Got: %v", transformedCols) + } +} diff --git a/flow/connectors/bigquery/merge_statement_generator.go b/flow/connectors/bigquery/merge_stmt_generator.go similarity index 93% rename from flow/connectors/bigquery/merge_statement_generator.go rename to flow/connectors/bigquery/merge_stmt_generator.go index 7e6d94dc9b..d93e31ad89 100644 --- a/flow/connectors/bigquery/merge_statement_generator.go +++ b/flow/connectors/bigquery/merge_stmt_generator.go @@ -47,7 +47,7 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string { var castStmt string shortCol := m.shortColumn[colName] switch qvalue.QValueKind(colType) { - case qvalue.QValueKindJSON: + case qvalue.QValueKindJSON, qvalue.QValueKindHStore: // if the type is JSON, then just extract JSON castStmt = fmt.Sprintf("CAST(PARSE_JSON(JSON_VALUE(_peerdb_data, '$.%s'),wide_number_mode=>'round') AS %s) AS `%s`", colName, bqType, shortCol) @@ -139,8 +139,7 @@ func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) s insertColumnsSQL := csep + fmt.Sprintf(", `%s`", m.peerdbCols.SyncedAtColName) insertValuesSQL := shortCsep + ",CURRENT_TIMESTAMP" - updateStatementsforToastCols := m.generateUpdateStatements(pureColNames, - unchangedToastColumns, m.peerdbCols) + updateStatementsforToastCols := m.generateUpdateStatements(pureColNames) if m.peerdbCols.SoftDelete { softDeleteInsertColumnsSQL := insertColumnsSQL + fmt.Sprintf(",`%s`", m.peerdbCols.SoftDeleteColName) softDeleteInsertValuesSQL := insertValuesSQL + ",TRUE" @@ -210,14 +209,17 @@ and updating the other columns (not the unchanged toast columns) 6. Repeat steps 1-5 for each unique unchanged toast column group. 7. Return the list of generated update statements. 
*/ -func (m *mergeStmtGenerator) generateUpdateStatements( - allCols []string, - unchangedToastCols []string, - peerdbCols *protos.PeerDBColumns, -) []string { - updateStmts := make([]string, 0, len(unchangedToastCols)) +func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string { + handleSoftDelete := m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "") + // weird way of doing it but avoids prealloc lint + updateStmts := make([]string, 0, func() int { + if handleSoftDelete { + return 2 * len(m.unchangedToastColumns) + } + return len(m.unchangedToastColumns) + }()) - for _, cols := range unchangedToastCols { + for _, cols := range m.unchangedToastColumns { unchangedColsArray := strings.Split(cols, ",") otherCols := utils.ArrayMinus(allCols, unchangedColsArray) tmpArray := make([]string, 0, len(otherCols)) @@ -226,14 +228,14 @@ func (m *mergeStmtGenerator) generateUpdateStatements( } // set the synced at column to the current timestamp - if peerdbCols.SyncedAtColName != "" { + if m.peerdbCols.SyncedAtColName != "" { tmpArray = append(tmpArray, fmt.Sprintf("`%s`=CURRENT_TIMESTAMP", - peerdbCols.SyncedAtColName)) + m.peerdbCols.SyncedAtColName)) } // set soft-deleted to false, tackles insert after soft-delete - if peerdbCols.SoftDeleteColName != "" { + if handleSoftDelete { tmpArray = append(tmpArray, fmt.Sprintf("`%s`=FALSE", - peerdbCols.SoftDeleteColName)) + m.peerdbCols.SoftDeleteColName)) } ssep := strings.Join(tmpArray, ",") @@ -245,9 +247,9 @@ func (m *mergeStmtGenerator) generateUpdateStatements( // generates update statements for the case where updates and deletes happen in the same branch // the backfill has happened from the pull side already, so treat the DeleteRecord as an update // and then set soft-delete to true. 
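On the slice-capacity computation earlier in this hunk: the inline `func() int` only exists to keep `make` in a shape the prealloc linter accepts, as the comment notes. An equivalent, arguably clearer spelling, shown as a sketch outside this PR:

```go
package main

import "fmt"

// sizeUpdateStmts computes the same capacity as the inline func() int above:
// one UPDATE branch per unchanged-TOAST-column group, doubled when the
// soft-delete branches are also generated.
func sizeUpdateStmts(unchangedToastColumns []string, handleSoftDelete bool) int {
	capacity := len(unchangedToastColumns)
	if handleSoftDelete {
		capacity *= 2
	}
	return capacity
}

func main() {
	fmt.Println(sizeUpdateStmts([]string{"", "col2,col3", "col2", "col3"}, true)) // 8
}
```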
- if peerdbCols.SoftDelete && (peerdbCols.SoftDeleteColName != "") { + if handleSoftDelete { tmpArray = append(tmpArray[:len(tmpArray)-1], - fmt.Sprintf("`%s`=TRUE", peerdbCols.SoftDeleteColName)) + fmt.Sprintf("`%s`=TRUE", m.peerdbCols.SoftDeleteColName)) ssep := strings.Join(tmpArray, ",") updateStmt := fmt.Sprintf(`WHEN MATCHED AND _rt=2 AND _ut='%s' diff --git a/flow/connectors/bigquery/merge_stmt_generator_test.go b/flow/connectors/bigquery/merge_stmt_generator_test.go index 141b3999b7..cc49b17cbd 100644 --- a/flow/connectors/bigquery/merge_stmt_generator_test.go +++ b/flow/connectors/bigquery/merge_stmt_generator_test.go @@ -2,77 +2,67 @@ package connbigquery import ( "reflect" - "strings" "testing" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" ) -func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { +func TestGenerateUpdateStatement(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} m := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, shortColumn: map[string]string{ "col1": "_c0", "col2": "_c1", "col3": "_c2", }, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SoftDeleteColName: "deleted", + SyncedAtColName: "synced_at", + }, } - allCols := []string{"col1", "col2", "col3"} - unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} expected := []string{ - "WHEN MATCHED AND _rt!=2 AND _ut=''" + - " THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1,`col3`=_d._c2," + - "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", - "WHEN MATCHED AND _rt=2 " + - "AND _ut='' " + - "THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1," + - "`col3`=_d._c2,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", - "WHEN MATCHED AND _rt!=2 AND _ut='col2,col3' " + - "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", - "WHEN MATCHED AND _rt=2 AND _ut='col2,col3' " + - "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", "WHEN MATCHED AND _rt!=2 " + - "AND _ut='col2' " + - "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + - "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", - "WHEN MATCHED AND _rt=2 " + - "AND _ut='col2' " + - "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + - "`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE ", - "WHEN MATCHED AND _rt!=2 AND _ut='col3' " + - "THEN UPDATE SET `col1`=_d._c0," + - "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", - "WHEN MATCHED AND _rt=2 AND _ut='col3' " + - "THEN UPDATE SET `col1`=_d._c0," + - "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + "AND _ut=''" + + "THEN UPDATE SET " + + "`col1`=_d._c0," + + "`col2`=_d._c1," + + "`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP", } - result := m.generateUpdateStatements(allCols, unchangedToastCols, &protos.PeerDBColumns{ - SoftDelete: true, - SoftDeleteColName: "deleted", - SyncedAtColName: "synced_at", - }) + result := m.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { - t.Errorf("Unexpected result. Expected: %v,\nbut got: %v", expected, result) + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) } } -func TestGenerateUpdateStatement_NoUnchangedToastCols(t *testing.T) { +func TestGenerateUpdateStatement_WithSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} m := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, shortColumn: map[string]string{ "col1": "_c0", "col2": "_c1", "col3": "_c2", }, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SoftDeleteColName: "deleted", + SyncedAtColName: "synced_at", + }, } - allCols := []string{"col1", "col2", "col3"} - unchangedToastCols := []string{""} expected := []string{ "WHEN MATCHED AND _rt!=2 " + @@ -89,26 +79,115 @@ func TestGenerateUpdateStatement_NoUnchangedToastCols(t *testing.T) { "`col3`=_d._c2,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", } - result := m.generateUpdateStatements(allCols, unchangedToastCols, - &protos.PeerDBColumns{ - SoftDelete: true, + result := m.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + m := &mergeStmtGenerator{ + shortColumn: map[string]string{ + "col1": "_c0", + "col2": "_c1", + "col3": "_c2", + }, + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, SoftDeleteColName: "deleted", SyncedAtColName: "synced_at", - }) + }, + } + + expected := []string{ + "WHEN MATCHED AND _rt!=2 AND _ut=''" + + " THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP", + "WHEN MATCHED AND _rt!=2 AND _ut='col2,col3' " + + "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP", + "WHEN MATCHED AND _rt!=2 " + + "AND _ut='col2' " + + "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP", + "WHEN MATCHED AND _rt!=2 AND _ut='col3' " + + "THEN UPDATE SET `col1`=_d._c0," + + "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP", + } + + result := m.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { - t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + t.Errorf("Unexpected result. 
Expected: %v,\nbut got: %v", expected, result) } } -func removeSpacesTabsNewlines(s string) string { - s = strings.ReplaceAll(s, " ", "") - s = strings.ReplaceAll(s, "\t", "") - s = strings.ReplaceAll(s, "\n", "") - return s +func TestGenerateUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + m := &mergeStmtGenerator{ + shortColumn: map[string]string{ + "col1": "_c0", + "col2": "_c1", + "col3": "_c2", + }, + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SoftDeleteColName: "deleted", + SyncedAtColName: "synced_at", + }, + } + + expected := []string{ + "WHEN MATCHED AND _rt!=2 AND _ut=''" + + " THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", + "WHEN MATCHED AND _rt=2 " + + "AND _ut='' " + + "THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1," + + "`col3`=_d._c2,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + "WHEN MATCHED AND _rt!=2 AND _ut='col2,col3' " + + "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", + "WHEN MATCHED AND _rt=2 AND _ut='col2,col3' " + + "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + "WHEN MATCHED AND _rt!=2 " + + "AND _ut='col2' " + + "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", + "WHEN MATCHED AND _rt=2 " + + "AND _ut='col2' " + + "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE ", + "WHEN MATCHED AND _rt!=2 AND _ut='col3' " + + "THEN UPDATE SET `col1`=_d._c0," + + "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", + "WHEN MATCHED AND _rt=2 AND _ut='col3' " + + "THEN UPDATE SET `col1`=_d._c0," + + "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + } + + result := m.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v,\nbut got: %v", expected, result) + } } diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index d6df8fdb6e..153d4fe1ae 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -110,9 +110,9 @@ func (s *QRepAvroSyncMethod) SyncRecords( return numRecords, nil } -func getTransformedColumns(dstTableMetadata *bigquery.TableMetadata, syncedAtCol string, softDeleteCol string) []string { - transformedColumns := make([]string, 0, len(dstTableMetadata.Schema)) - for _, col := range dstTableMetadata.Schema { +func getTransformedColumns(dstSchema *bigquery.Schema, syncedAtCol string, softDeleteCol string) []string { + transformedColumns := make([]string, 0, len(*dstSchema)) + for _, col := range *dstSchema { if col.Name == syncedAtCol || col.Name == softDeleteCol { continue } @@ -174,7 +174,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( ) bqClient := s.connector.client - transformedColumns := getTransformedColumns(dstTableMetadata, syncedAtCol, softDeleteCol) + transformedColumns := getTransformedColumns(&dstTableMetadata.Schema, syncedAtCol, softDeleteCol) selector := strings.Join(transformedColumns, ", ") if softDeleteCol != "" { // PeerDB column diff --git a/flow/connectors/bigquery/qvalue_convert.go b/flow/connectors/bigquery/qvalue_convert.go index 654c5cdc24..7e98eabd15 100644 --- a/flow/connectors/bigquery/qvalue_convert.go +++ b/flow/connectors/bigquery/qvalue_convert.go @@ -24,7 +24,7 @@ func qValueKindToBigQueryType(colType string) bigquery.FieldType { case qvalue.QValueKindString: return bigquery.StringFieldType // json also is stored as string for now - case qvalue.QValueKindJSON: + case qvalue.QValueKindJSON, qvalue.QValueKindHStore: return bigquery.JSONFieldType // time related case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ: diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index b51741aff6..0a99bce668 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -5,18 +5,15 @@ import ( "fmt" "log" "regexp" - "slices" "strings" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" - "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/lib/pq/oid" - "golang.org/x/exp/maps" ) const ( @@ -599,215 +596,6 @@ func (c *PostgresConnector) getTableNametoUnchangedCols(flowJobName string, sync return resultMap, nil } -func (c *PostgresConnector) generateNormalizeStatements(destinationTableIdentifier string, - unchangedToastColumns []string, rawTableIdentifier string, supportsMerge bool, - peerdbCols *protos.PeerDBColumns, -) []string { - if supportsMerge { - return []string{c.generateMergeStatement(destinationTableIdentifier, unchangedToastColumns, - rawTableIdentifier, peerdbCols)} - } - c.logger.Warn("Postgres version is not high enough to support MERGE, falling back to UPSERT + DELETE") - c.logger.Warn("TOAST columns will not be updated properly, use REPLICA IDENTITY FULL or upgrade Postgres") - return c.generateFallbackStatements(destinationTableIdentifier, rawTableIdentifier, peerdbCols) -} - -func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifier string, - rawTableIdentifier string, peerdbCols *protos.PeerDBColumns, -) []string { - normalizedTableSchema := 
c.tableSchemaMapping[destinationTableIdentifier] - columnCount := utils.TableSchemaColumns(normalizedTableSchema) - columnNames := make([]string, 0, columnCount) - flattenedCastsSQLArray := make([]string, 0, columnCount) - primaryKeyColumnCasts := make(map[string]string) - utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) { - columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName)) - pgType := qValueKindToPostgresType(genericColumnType) - if qvalue.QValueKind(genericColumnType).IsArray() { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } - if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) - } - }) - flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") - parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier) - - insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",") - updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema)) - utils.IterColumns(normalizedTableSchema, func(columnName, _ string) { - updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName)) - }) - updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",") - deleteWhereClauseArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - for columnName, columnCast := range primaryKeyColumnCasts { - deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `, - parsedDstTable.String(), columnName, columnCast)) - } - deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ") - deletePart := fmt.Sprintf( - "DELETE FROM %s USING", - parsedDstTable.String()) - - if peerdbCols.SoftDelete { - deletePart = fmt.Sprintf(`UPDATE %s SET "%s" = TRUE`, - parsedDstTable.String(), peerdbCols.SoftDeleteColName) - if peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf(`%s, "%s" = CURRENT_TIMESTAMP`, - deletePart, peerdbCols.SyncedAtColName) - } - deletePart += " FROM" - } - fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL, - strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), c.metadataSchema, - rawTableIdentifier, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL, - strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL) - fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL, - strings.Join(maps.Values(primaryKeyColumnCasts), ","), c.metadataSchema, - rawTableIdentifier, deletePart, deleteWhereClauseSQL) - - return []string{fallbackUpsertStatement, fallbackDeleteStatement} -} - -func (c *PostgresConnector) generateMergeStatement( - destinationTableIdentifier string, - unchangedToastColumns []string, - rawTableIdentifier string, - peerdbCols *protos.PeerDBColumns, -) string { - normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier] - columnNames := utils.TableSchemaColumnNames(normalizedTableSchema) - for i, columnName := range columnNames { - columnNames[i] = 
fmt.Sprintf("\"%s\"", columnName) - } - - flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema)) - parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier) - - primaryKeyColumnCasts := make(map[string]string) - primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) { - pgType := qValueKindToPostgresType(genericColumnType) - if qvalue.QValueKind(genericColumnType).IsArray() { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } - if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) - primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", - columnName, columnName)) - } - }) - flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") - insertValuesSQLArray := make([]string, 0, len(columnNames)) - for _, columnName := range columnNames { - insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName)) - } - - updateStatementsforToastCols := c.generateUpdateStatement(columnNames, unchangedToastColumns, peerdbCols) - // append synced_at column - columnNames = append(columnNames, fmt.Sprintf(`"%s"`, peerdbCols.SyncedAtColName)) - insertColumnsSQL := strings.Join(columnNames, ",") - // fill in synced_at column - insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") - insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",") - - if peerdbCols.SoftDelete { - softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames, - fmt.Sprintf(`"%s"`, peerdbCols.SoftDeleteColName)), ","), ",") - softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",") - - updateStatementsforToastCols = append(updateStatementsforToastCols, - fmt.Sprintf("WHEN NOT MATCHED AND (src._peerdb_record_type = 2) THEN INSERT (%s) VALUES(%s)", - softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) - } - updateStringToastCols := strings.Join(updateStatementsforToastCols, "\n") - - deletePart := "DELETE" - if peerdbCols.SoftDelete { - colName := peerdbCols.SoftDeleteColName - deletePart = fmt.Sprintf(`UPDATE SET "%s" = TRUE`, colName) - if peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf(`%s, "%s" = CURRENT_TIMESTAMP`, - deletePart, peerdbCols.SyncedAtColName) - } - } - - mergeStmt := fmt.Sprintf( - mergeStatementSQL, - strings.Join(maps.Values(primaryKeyColumnCasts), ","), - c.metadataSchema, - rawTableIdentifier, - parsedDstTable.String(), - flattenedCastsSQL, - strings.Join(primaryKeySelectSQLArray, " AND "), - insertColumnsSQL, - insertValuesSQL, - updateStringToastCols, - deletePart, - ) - - return mergeStmt -} - -func (c *PostgresConnector) generateUpdateStatement(allCols []string, - unchangedToastColsLists []string, peerdbCols *protos.PeerDBColumns, -) []string { - updateStmts := make([]string, 0, len(unchangedToastColsLists)) - - for _, cols := range unchangedToastColsLists { - unquotedUnchangedColsArray := 
strings.Split(cols, ",") - unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) - for _, unchangedToastCol := range unquotedUnchangedColsArray { - unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol)) - } - otherCols := utils.ArrayMinus(allCols, unchangedColsArray) - tmpArray := make([]string, 0, len(otherCols)) - for _, colName := range otherCols { - tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName)) - } - // set the synced at column to the current timestamp - if peerdbCols.SyncedAtColName != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, - peerdbCols.SyncedAtColName)) - } - // set soft-deleted to false, tackles insert after soft-delete - if peerdbCols.SoftDelete && (peerdbCols.SoftDeleteColName != "") { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = FALSE`, - peerdbCols.SoftDeleteColName)) - } - - ssep := strings.Join(tmpArray, ",") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - - // generates update statements for the case where updates and deletes happen in the same branch - // the backfill has happened from the pull side already, so treat the DeleteRecord as an update - // and then set soft-delete to true. - if peerdbCols.SoftDelete && (peerdbCols.SoftDeleteColName != "") { - tmpArray = append(tmpArray[:len(tmpArray)-1], - fmt.Sprintf(`"%s" = TRUE`, peerdbCols.SoftDeleteColName)) - ssep := strings.Join(tmpArray, ", ") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - } - } - return updateStmts -} - func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) { row := c.pool.QueryRow(c.ctx, "SELECT CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END") diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go new file mode 100644 index 0000000000..b541543fe2 --- /dev/null +++ b/flow/connectors/postgres/normalize_stmt_generator.go @@ -0,0 +1,235 @@ +package connpostgres + +import ( + "fmt" + "log/slog" + "slices" + "strings" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model/qvalue" + "golang.org/x/exp/maps" +) + +type normalizeStmtGenerator struct { + rawTableName string + // destination table name, used to retrieve records from raw table + dstTableName string + // the schema of the table to merge into + normalizedTableSchema *protos.TableSchema + // array of toast column combinations that are unchanged + unchangedToastColumns []string + // _PEERDB_IS_DELETED and _SYNCED_AT columns + peerdbCols *protos.PeerDBColumns + // Postgres version 15 introduced MERGE, fallback statements before that + supportsMerge bool + // Postgres metadata schema + metadataSchema string + // to log fallback statement selection + logger slog.Logger +} + +func (n *normalizeStmtGenerator) generateNormalizeStatements() []string { + if n.supportsMerge { + return []string{n.generateMergeStatement()} + } + n.logger.Warn("Postgres version is not high enough to support MERGE, falling back to UPSERT+DELETE") + n.logger.Warn("TOAST columns will not be updated properly, use REPLICA IDENTITY FULL or upgrade Postgres") + 
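`generateNormalizeStatements` above branches on `supportsMerge`, and the struct comment notes that MERGE arrived in Postgres 15. How that flag is populated is not part of this diff; a sketch of one common way to derive it, assuming a pgx pool:

```go
package pgversion

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5/pgxpool"
)

// postgresSupportsMerge is a sketch of deriving a supportsMerge flag:
// MERGE is available from Postgres 15, i.e. server_version_num >= 150000.
// (Illustrative only; the actual check lives outside this diff.)
func postgresSupportsMerge(ctx context.Context, pool *pgxpool.Pool) (bool, error) {
	var versionNum int
	err := pool.QueryRow(ctx,
		"SELECT current_setting('server_version_num')::int").Scan(&versionNum)
	if err != nil {
		return false, fmt.Errorf("failed to read server_version_num: %w", err)
	}
	return versionNum >= 150000, nil
}
```

Servers reporting 150000 or higher can take the MERGE path; older ones fall back to the UPSERT+DELETE statements generated above.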
if n.peerdbCols.SoftDelete { + n.logger.Warn("soft delete enabled with fallback statements! this combination is unsupported") + } + return n.generateFallbackStatements() +} + +func (n *normalizeStmtGenerator) generateFallbackStatements() []string { + columnCount := utils.TableSchemaColumns(n.normalizedTableSchema) + columnNames := make([]string, 0, columnCount) + flattenedCastsSQLArray := make([]string, 0, columnCount) + primaryKeyColumnCasts := make(map[string]string) + utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { + columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName)) + pgType := qValueKindToPostgresType(genericColumnType) + if qvalue.QValueKind(genericColumnType).IsArray() { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } else { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } + if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + } + }) + flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") + parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName) + + insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",") + updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) + utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) { + updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName)) + }) + updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",") + deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) + for columnName, columnCast := range primaryKeyColumnCasts { + deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `, + parsedDstTable.String(), columnName, columnCast)) + } + deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ") + deletePart := fmt.Sprintf( + "DELETE FROM %s USING", + parsedDstTable.String()) + + if n.peerdbCols.SoftDelete { + deletePart = fmt.Sprintf(`UPDATE %s SET "%s"=TRUE`, + parsedDstTable.String(), n.peerdbCols.SoftDeleteColName) + if n.peerdbCols.SyncedAtColName != "" { + deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, + deletePart, n.peerdbCols.SyncedAtColName) + } + deletePart += " FROM" + } + fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL, + strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), n.metadataSchema, + n.rawTableName, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL, + strings.Join(n.normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL) + fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL, + strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema, + n.rawTableName, deletePart, deleteWhereClauseSQL) + + return []string{fallbackUpsertStatement, fallbackDeleteStatement} +} + +func (n *normalizeStmtGenerator) generateMergeStatement() string { + columnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema) + for i, columnName := range columnNames { + columnNames[i] = 
fmt.Sprintf("\"%s\"", columnName) + } + + flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) + parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName) + + primaryKeyColumnCasts := make(map[string]string) + primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) + utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { + pgType := qValueKindToPostgresType(genericColumnType) + if qvalue.QValueKind(genericColumnType).IsArray() { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } else { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } + if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", + columnName, columnName)) + } + }) + flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") + insertValuesSQLArray := make([]string, 0, len(columnNames)) + for _, columnName := range columnNames { + insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName)) + } + + updateStatementsforToastCols := n.generateUpdateStatements(columnNames) + // append synced_at column + columnNames = append(columnNames, fmt.Sprintf(`"%s"`, n.peerdbCols.SyncedAtColName)) + insertColumnsSQL := strings.Join(columnNames, ",") + // fill in synced_at column + insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") + insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",") + + if n.peerdbCols.SoftDelete { + softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames, + fmt.Sprintf(`"%s"`, n.peerdbCols.SoftDeleteColName)), ","), ",") + softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",") + + updateStatementsforToastCols = append(updateStatementsforToastCols, + fmt.Sprintf("WHEN NOT MATCHED AND (src._peerdb_record_type=2) THEN INSERT (%s) VALUES(%s)", + softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) + } + updateStringToastCols := strings.Join(updateStatementsforToastCols, "\n") + + deletePart := "DELETE" + if n.peerdbCols.SoftDelete { + colName := n.peerdbCols.SoftDeleteColName + deletePart = fmt.Sprintf(`UPDATE SET "%s"=TRUE`, colName) + if n.peerdbCols.SyncedAtColName != "" { + deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, + deletePart, n.peerdbCols.SyncedAtColName) + } + } + + mergeStmt := fmt.Sprintf( + mergeStatementSQL, + strings.Join(maps.Values(primaryKeyColumnCasts), ","), + n.metadataSchema, + n.rawTableName, + parsedDstTable.String(), + flattenedCastsSQL, + strings.Join(primaryKeySelectSQLArray, " AND "), + insertColumnsSQL, + insertValuesSQL, + updateStringToastCols, + deletePart, + ) + + return mergeStmt +} + +func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []string { + handleSoftDelete := n.peerdbCols.SoftDelete && (n.peerdbCols.SoftDeleteColName != "") + // weird way of doing it but avoids prealloc lint + updateStmts := make([]string, 0, func() int { + if handleSoftDelete { + return 2 * 
len(n.unchangedToastColumns) + } + return len(n.unchangedToastColumns) + }()) + + for _, cols := range n.unchangedToastColumns { + unquotedUnchangedColsArray := strings.Split(cols, ",") + unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) + for _, unchangedToastCol := range unquotedUnchangedColsArray { + unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol)) + } + otherCols := utils.ArrayMinus(allCols, unchangedColsArray) + tmpArray := make([]string, 0, len(otherCols)) + for _, colName := range otherCols { + tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName)) + } + // set the synced at column to the current timestamp + if n.peerdbCols.SyncedAtColName != "" { + tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=CURRENT_TIMESTAMP`, + n.peerdbCols.SyncedAtColName)) + } + // set soft-deleted to false, tackles insert after soft-delete + if handleSoftDelete { + tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=FALSE`, + n.peerdbCols.SoftDeleteColName)) + } + + ssep := strings.Join(tmpArray, ",") + updateStmt := fmt.Sprintf(`WHEN MATCHED AND + src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s' + THEN UPDATE SET %s`, cols, ssep) + updateStmts = append(updateStmts, updateStmt) + + // generates update statements for the case where updates and deletes happen in the same branch + // the backfill has happened from the pull side already, so treat the DeleteRecord as an update + // and then set soft-delete to true. + if handleSoftDelete { + tmpArray = append(tmpArray[:len(tmpArray)-1], + fmt.Sprintf(`"%s"=TRUE`, n.peerdbCols.SoftDeleteColName)) + ssep := strings.Join(tmpArray, ", ") + updateStmt := fmt.Sprintf(`WHEN MATCHED AND + src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s' + THEN UPDATE SET %s `, cols, ssep) + updateStmts = append(updateStmts, updateStmt) + } + } + return updateStmts +} diff --git a/flow/connectors/postgres/normalize_stmt_generator_test.go b/flow/connectors/postgres/normalize_stmt_generator_test.go new file mode 100644 index 0000000000..637c69f2bc --- /dev/null +++ b/flow/connectors/postgres/normalize_stmt_generator_test.go @@ -0,0 +1,148 @@ +package connpostgres + +import ( + "reflect" + "testing" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" +) + +func TestGenerateMergeUpdateStatement(t *testing.T) { + allCols := []string{`"col1"`, `"col2"`, `"col3"`} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateMergeUpdateStatement_WithSoftDelete(t *testing.T) { + allCols := []string{`"col1"`, `"col2"`, `"col3"`} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateMergeUpdateStatement_WithUnchangedToastCols(t *testing.T) { + allCols := []string{`"col1"`, `"col2"`, `"col3"`} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3","_peerdb_synced_at"=CURRENT_TIMESTAMP`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2,col3' + THEN UPDATE SET "col1"=src."col1","_peerdb_synced_at"=CURRENT_TIMESTAMP`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2' + THEN UPDATE SET "col1"=src."col1","col3"=src."col3","_peerdb_synced_at"=CURRENT_TIMESTAMP`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col3' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","_peerdb_synced_at"=CURRENT_TIMESTAMP`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateMergeUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing.T) { + allCols := []string{`"col1"`, `"col2"`, `"col3"`} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2,col3' + THEN UPDATE SET "col1"=src."col1","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='col2,col3' + THEN UPDATE SET "col1"=src."col1","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2' + THEN UPDATE SET "col1"=src."col1","col3"=src."col3","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='col2' + THEN UPDATE SET "col1"=src."col1","col3"=src."col3","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col3' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='col3' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 43a11a2e72..e57abf33f3 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -101,6 +101,7 @@ func (c *PostgresConnector) GetReplPool(ctx context.Context) (*SSHWrappedPostgre pool, err := NewSSHWrappedPostgresPool(ctx, c.replConfig, c.config.SshConfig) if err != nil { + slog.Error("failed to create replication connection pool", slog.Any("error", err)) return nil, fmt.Errorf("failed to create replication connection pool: %w", err) } @@ -461,13 +462,21 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) mergeStatementsBatch := &pgx.Batch{} totalRowsAffected := 0 for _, destinationTableName := range destinationTableNames { - peerdbCols := protos.PeerDBColumns{ - SoftDeleteColName: req.SoftDeleteColName, - SyncedAtColName: req.SyncedAtColName, - SoftDelete: req.SoftDelete, + normalizeStmtGen := &normalizeStmtGenerator{ + rawTableName: rawTableIdentifier, + dstTableName: destinationTableName, + normalizedTableSchema: c.tableSchemaMapping[destinationTableName], + unchangedToastColumns: unchangedToastColsMap[destinationTableName], + peerdbCols: &protos.PeerDBColumns{ + SoftDeleteColName: req.SoftDeleteColName, + SyncedAtColName: req.SyncedAtColName, + SoftDelete: req.SoftDelete, + }, + supportsMerge: supportsMerge, + metadataSchema: c.metadataSchema, + logger: c.logger, } - normalizeStatements := c.generateNormalizeStatements(destinationTableName, unchangedToastColsMap[destinationTableName], - rawTableIdentifier, supportsMerge, &peerdbCols) + normalizeStatements := normalizeStmtGen.generateNormalizeStatements() for _, normalizeStatement := range normalizeStatements { mergeStatementsBatch.Queue(normalizeStatement, batchIDs.NormalizeBatchID, batchIDs.SyncBatchID, destinationTableName).Exec( func(ct pgconn.CommandTag) error { diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index 839845d284..84a837a31a 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -104,14 +104,11 @@ func (qe *QRepQueryExecutor) fieldDescriptionsToSchema(fds []pgconn.FieldDescrip cname := fd.Name ctype := postgresOIDToQValueKind(fd.DataTypeOID) if ctype == qvalue.QValueKindInvalid { - var err error - if err != nil { - typeName, ok := qe.customTypeMap[fd.DataTypeOID] - if ok { - ctype = customTypeToQKind(typeName) - } else { - ctype = qvalue.QValueKindString - } + typeName, ok := qe.customTypeMap[fd.DataTypeOID] + if ok { + ctype = customTypeToQKind(typeName) + } else { + ctype = qvalue.QValueKindString } } // there isn't a way to know if a column is nullable or not diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index 0037a43fa1..b105250ebd 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -111,6 +111,8 @@ func qValueKindToPostgresType(qvalueKind string) string { return "BYTEA" case qvalue.QValueKindJSON: return "JSONB" + case qvalue.QValueKindHStore: + return "HSTORE" case qvalue.QValueKindUUID: return "UUID" case qvalue.QValueKindTime: @@ -335,12 +337,6 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( default: return qvalue.QValue{}, fmt.Errorf("failed to parse array string: %v", value) } - case qvalue.QValueKindHStore: - hstoreVal, err := 
value.(pgtype.Hstore).HstoreValue() - if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to parse hstore: %w", err) - } - val = qvalue.QValue{Kind: qvalue.QValueKindHStore, Value: hstoreVal} case qvalue.QValueKindPoint: xCoord := value.(pgtype.Point).P.X yCoord := value.(pgtype.Point).P.Y @@ -399,6 +395,8 @@ func customTypeToQKind(typeName string) qvalue.QValueKind { qValueKind = qvalue.QValueKindGeometry case "geography": qValueKind = qvalue.QValueKindGeography + case "hstore": + qValueKind = qvalue.QValueKindHStore default: qValueKind = qvalue.QValueKindString } diff --git a/flow/connectors/postgres/ssh_wrapped_pool.go b/flow/connectors/postgres/ssh_wrapped_pool.go index c8608c52aa..a82356a721 100644 --- a/flow/connectors/postgres/ssh_wrapped_pool.go +++ b/flow/connectors/postgres/ssh_wrapped_pool.go @@ -81,17 +81,18 @@ func (swpp *SSHWrappedPostgresPool) connect() error { return } + host := swpp.poolConfig.ConnConfig.Host err = retryWithBackoff(func() error { err = swpp.Ping(swpp.ctx) if err != nil { - slog.Error("Failed to ping pool", slog.Any("error", err)) + slog.Error("Failed to ping pool", slog.Any("error", err), slog.String("host", host)) return err } return nil }, 5, 5*time.Second) if err != nil { - slog.Error("Failed to create pool", slog.Any("error", err)) + slog.Error("Failed to create pool", slog.Any("error", err), slog.String("host", host)) } }) diff --git a/flow/connectors/snowflake/avro_transform_test.go b/flow/connectors/snowflake/avro_transform_test.go new file mode 100644 index 0000000000..ffbe896658 --- /dev/null +++ b/flow/connectors/snowflake/avro_transform_test.go @@ -0,0 +1,23 @@ +package connsnowflake + +import "testing" + +func TestAvroTransform(t *testing.T) { + colNames := []string{"col1", "col2", "col3", "camelCol4", "sync_col"} + colTypes := []string{"GEOGRAPHY", "VARIANT", "NUMBER", "STRING", "TIMESTAMP_LTZ"} + + expectedTransform := `TO_GEOGRAPHY($1:"col1"::string, true) AS "COL1",` + + `PARSE_JSON($1:"col2") AS "COL2",` + + `$1:"col3" AS "COL3",` + + `($1:"camelCol4")::STRING AS "camelCol4",` + + `CURRENT_TIMESTAMP AS "SYNC_COL"` + transform, cols := GetTransformSQL(colNames, colTypes, "sync_col") + if transform != expectedTransform { + t.Errorf("Transform SQL is not correct. Got: %v", transform) + } + + expectedCols := `"COL1","COL2","COL3","camelCol4","SYNC_COL"` + if cols != expectedCols { + t.Errorf("Columns are not correct. Got:%v", cols) + } +} diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go new file mode 100644 index 0000000000..684f922dfd --- /dev/null +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -0,0 +1,215 @@ +package connsnowflake + +import ( + "fmt" + "strings" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model/qvalue" +) + +type mergeStmtGenerator struct { + rawTableName string + // destination table name, used to retrieve records from raw table + dstTableName string + // last synced batchID. + syncBatchID int64 + // last normalized batchID. 
+ normalizeBatchID int64 + // the schema of the table to merge into + normalizedTableSchema *protos.TableSchema + // array of toast column combinations that are unchanged + unchangedToastColumns []string + // _PEERDB_IS_DELETED and _SYNCED_AT columns + peerdbCols *protos.PeerDBColumns +} + +func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { + parsedDstTable, _ := utils.ParseSchemaTable(m.dstTableName) + columnNames := utils.TableSchemaColumnNames(m.normalizedTableSchema) + + flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)) + err := utils.IterColumnsError(m.normalizedTableSchema, func(columnName, genericColumnType string) error { + qvKind := qvalue.QValueKind(genericColumnType) + sfType, err := qValueKindToSnowflakeType(qvKind) + if err != nil { + return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err) + } + + targetColumnName := SnowflakeIdentifierNormalize(columnName) + switch qvalue.QValueKind(genericColumnType) { + case qvalue.QValueKindBytes, qvalue.QValueKindBit: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+ + "AS %s,", toVariantColumnName, columnName, targetColumnName)) + case qvalue.QValueKindGeography: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("TO_GEOGRAPHY(CAST(%s:\"%s\" AS STRING),true) AS %s,", + toVariantColumnName, columnName, targetColumnName)) + case qvalue.QValueKindGeometry: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s,", + toVariantColumnName, columnName, targetColumnName)) + case qvalue.QValueKindJSON: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s,", + toVariantColumnName, columnName, targetColumnName)) + // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types + // case model.ColumnTypeTime: + // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ + // "Microseconds*1000) "+ + // "AS %s,", toVariantColumnName, columnName, columnName)) + default: + if qvKind == qvalue.QValueKindNumeric { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s,", + toVariantColumnName, columnName, sfType, targetColumnName)) + } else { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s,", + toVariantColumnName, columnName, sfType, targetColumnName)) + } + } + return nil + }) + if err != nil { + return "", err + } + flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ""), ",") + + quotedUpperColNames := make([]string, 0, len(columnNames)) + for _, columnName := range columnNames { + quotedUpperColNames = append(quotedUpperColNames, SnowflakeIdentifierNormalize(columnName)) + } + // append synced_at column + quotedUpperColNames = append(quotedUpperColNames, + fmt.Sprintf(`"%s"`, strings.ToUpper(m.peerdbCols.SyncedAtColName)), + ) + + insertColumnsSQL := strings.TrimSuffix(strings.Join(quotedUpperColNames, ","), ",") + + insertValuesSQLArray := make([]string, 0, len(columnNames)) + for _, columnName := range columnNames { + normalizedColName := SnowflakeIdentifierNormalize(columnName) + insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("SOURCE.%s", normalizedColName)) + } + // fill in synced_at column + insertValuesSQLArray = 
append(insertValuesSQLArray, "CURRENT_TIMESTAMP") + insertValuesSQL := strings.Join(insertValuesSQLArray, ",") + updateStatementsforToastCols := m.generateUpdateStatements(columnNames) + + // handling the case when an insert and delete happen in the same batch, with updates in the middle + // with soft-delete, we want the row to be in the destination with SOFT_DELETE true + // the current merge statement doesn't do that, so we add another case to insert the DeleteRecord + if m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "") { + softDeleteInsertColumnsSQL := strings.Join(append(quotedUpperColNames, + m.peerdbCols.SoftDeleteColName), ",") + softDeleteInsertValuesSQL := insertValuesSQL + ",TRUE" + updateStatementsforToastCols = append(updateStatementsforToastCols, + fmt.Sprintf("WHEN NOT MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) THEN INSERT (%s) VALUES(%s)", + softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) + } + updateStringToastCols := strings.Join(updateStatementsforToastCols, " ") + + normalizedpkeyColsArray := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns)) + pkeySelectSQLArray := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns)) + for _, pkeyColName := range m.normalizedTableSchema.PrimaryKeyColumns { + normalizedPkeyColName := SnowflakeIdentifierNormalize(pkeyColName) + normalizedpkeyColsArray = append(normalizedpkeyColsArray, normalizedPkeyColName) + pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s", + normalizedPkeyColName, normalizedPkeyColName)) + } + // TARGET. = SOURCE. AND TARGET. = SOURCE. ... + pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ") + + deletePart := "DELETE" + if m.peerdbCols.SoftDelete { + colName := m.peerdbCols.SoftDeleteColName + deletePart = fmt.Sprintf("UPDATE SET %s = TRUE", colName) + if m.peerdbCols.SyncedAtColName != "" { + deletePart = fmt.Sprintf("%s, %s = CURRENT_TIMESTAMP", deletePart, m.peerdbCols.SyncedAtColName) + } + } + + mergeStatement := fmt.Sprintf(mergeStatementSQL, snowflakeSchemaTableNormalize(parsedDstTable), + toVariantColumnName, m.rawTableName, m.normalizeBatchID, m.syncBatchID, flattenedCastsSQL, + fmt.Sprintf("(%s)", strings.Join(normalizedpkeyColsArray, ",")), + pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart) + + return mergeStatement, nil +} + +/* +This function generates UPDATE statements for a MERGE operation based on the provided inputs. + +Inputs: +1. allCols: An array of all column names. +2. unchangedToastCols: An array capturing unique sets of unchanged toast column groups. +3. softDeleteCol: just set to false in the case we see an insert after a soft-deleted column +4. syncedAtCol: set to the CURRENT_TIMESTAMP + +Algorithm: +1. Iterate over each unique set of unchanged toast column groups. +2. For each group, split it into individual column names. +3. Calculate the other columns by finding the set difference between allCols and the unchanged columns. +4. Generate an update statement for the current group by setting the appropriate conditions +and updating the other columns. + - The condition includes checking if the _PEERDB_RECORD_TYPE is not 2 (not a DELETE) and if the + _PEERDB_UNCHANGED_TOAST_COLUMNS match the current group. + - The update sets the other columns to their corresponding values + from the SOURCE table. It doesn't set (make null the Unchanged toast columns. + +5. Append the update statement to the list of generated statements. +6. 
+7. Return the list of generated update statements.
+*/
+func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string {
+ handleSoftDelete := m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "")
+ // weird way of doing it but avoids prealloc lint
+ updateStmts := make([]string, 0, func() int {
+ if handleSoftDelete {
+ return 2 * len(m.unchangedToastColumns)
+ }
+ return len(m.unchangedToastColumns)
+ }())
+
+ for _, cols := range m.unchangedToastColumns {
+ unchangedColsArray := strings.Split(cols, ",")
+ otherCols := utils.ArrayMinus(allCols, unchangedColsArray)
+ tmpArray := make([]string, 0, len(otherCols)+2)
+ for _, colName := range otherCols {
+ normalizedColName := SnowflakeIdentifierNormalize(colName)
+ tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", normalizedColName, normalizedColName))
+ }
+
+ // set the synced at column to the current timestamp
+ if m.peerdbCols.SyncedAtColName != "" {
+ tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`,
+ m.peerdbCols.SyncedAtColName))
+ }
+ // set soft-deleted to false, tackles insert after soft-delete
+ if handleSoftDelete {
+ tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = FALSE`,
+ m.peerdbCols.SoftDeleteColName))
+ }
+
+ ssep := strings.Join(tmpArray, ", ")
+ updateStmt := fmt.Sprintf(`WHEN MATCHED AND
+ (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s'
+ THEN UPDATE SET %s `, cols, ssep)
+ updateStmts = append(updateStmts, updateStmt)
+
+ // generates update statements for the case where updates and deletes happen in the same batch
+ // the backfill has happened from the pull side already, so treat the DeleteRecord as an update
+ // and then set soft-delete to true.
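+ // For example (illustrative, mirrored from the unit tests in merge_stmt_generator_test.go below,
+ // not new behavior): with allCols = [col1, col2, col3] and cols = "col2,col3", this branch emits
+ //   WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3'
+ //   THEN UPDATE SET "COL1" = SOURCE."COL1", "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE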
+ if handleSoftDelete { + tmpArray = append(tmpArray[:len(tmpArray)-1], fmt.Sprintf(`"%s" = TRUE`, + m.peerdbCols.SoftDeleteColName)) + ssep := strings.Join(tmpArray, ", ") + updateStmt := fmt.Sprintf(`WHEN MATCHED AND + (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s' + THEN UPDATE SET %s `, cols, ssep) + updateStmts = append(updateStmts, updateStmt) + } + } + return updateStmts +} diff --git a/flow/connectors/snowflake/merge_stmt_generator_test.go b/flow/connectors/snowflake/merge_stmt_generator_test.go index c4eb2e973e..f8b70f566a 100644 --- a/flow/connectors/snowflake/merge_stmt_generator_test.go +++ b/flow/connectors/snowflake/merge_stmt_generator_test.go @@ -2,34 +2,104 @@ package connsnowflake import ( "reflect" - "strings" "testing" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" ) +func TestGenerateUpdateStatement(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, + } + mergeGen := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_PEERDB_SYNCED_AT", + SoftDeleteColName: "_PEERDB_SOFT_DELETE", + }, + } + result := mergeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateUpdateStatement_WithSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`, + } + mergeGen := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_PEERDB_SYNCED_AT", + SoftDeleteColName: "_PEERDB_SOFT_DELETE", + }, + } + result := mergeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} + func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { - c := &SnowflakeConnector{} allCols := []string{"col1", "col2", "col3"} unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} expected := []string{ `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", - "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3' THEN UPDATE SET "COL1" = SOURCE."COL1", - "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2' THEN UPDATE SET "COL1" = SOURCE."COL1", "COL3" = SOURCE."COL3", - "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col3' THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", - "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, } - result := c.generateUpdateStatements("_PEERDB_SYNCED_AT", "_PEERDB_SOFT_DELETE", false, allCols, unchangedToastCols) + mergeGen := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_PEERDB_SYNCED_AT", + SoftDeleteColName: "_PEERDB_SOFT_DELETE", + }, + } + result := mergeGen.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { @@ -37,31 +107,52 @@ func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { } } -func TestGenerateUpdateStatement_EmptyColumns(t *testing.T) { - c := &SnowflakeConnector{} +func TestGenerateUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing.T) { allCols := []string{"col1", "col2", "col3"} - unchangedToastCols := []string{""} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} expected := []string{ `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' - THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3' + THEN UPDATE SET "COL1" = SOURCE."COL1", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3' + THEN UPDATE SET "COL1" = SOURCE."COL1", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, 
"_PEERDB_SOFT_DELETE" = TRUE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col3' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col3' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`, } - result := c.generateUpdateStatements("_PEERDB_SYNCED_AT", "_PEERDB_SOFT_DELETE", false, allCols, unchangedToastCols) + mergeGen := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_PEERDB_SYNCED_AT", + SoftDeleteColName: "_PEERDB_SOFT_DELETE", + }, + } + result := mergeGen.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) } } - -func removeSpacesTabsNewlines(s string) string { - s = strings.ReplaceAll(s, " ", "") - s = strings.ReplaceAll(s, "\t", "") - s = strings.ReplaceAll(s, "\n", "") - return s -} diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index 83521088d8..47484a06f8 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -292,15 +292,7 @@ func (s *SnowflakeAvroSyncMethod) putFileToStage(avroFile *avro.AvroFile, stage return nil } -func (c *SnowflakeConnector) GetCopyTransformation( - dstTableName string, - syncedAtCol string, -) (*CopyInfo, error) { - colNames, colTypes, colsErr := c.getColsFromTable(dstTableName) - if colsErr != nil { - return nil, fmt.Errorf("failed to get columns from destination table: %w", colsErr) - } - +func GetTransformSQL(colNames []string, colTypes []string, syncedAtCol string) (string, string) { transformations := make([]string, 0, len(colNames)) columnOrder := make([]string, 0, len(colNames)) for idx, avroColName := range colNames { @@ -337,6 +329,20 @@ func (c *SnowflakeConnector) GetCopyTransformation( } transformationSQL := strings.Join(transformations, ",") columnsSQL := strings.Join(columnOrder, ",") + + return transformationSQL, columnsSQL +} + +func (c *SnowflakeConnector) GetCopyTransformation( + dstTableName string, + syncedAtCol string, +) (*CopyInfo, error) { + colNames, colTypes, colsErr := c.getColsFromTable(dstTableName) + if colsErr != nil { + return nil, fmt.Errorf("failed to get columns from destination table: %w", colsErr) + } + + transformationSQL, columnsSQL := GetTransformSQL(colNames, colTypes, syncedAtCol) return &CopyInfo{transformationSQL, columnsSQL}, nil } diff --git a/flow/connectors/snowflake/snowflake.go 
b/flow/connectors/snowflake/snowflake.go index bb1eb4240c..fe1326ebee 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -638,18 +638,43 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest tableName := destinationTableName // local variable for the closure g.Go(func() error { - rowsAffected, err := c.generateAndExecuteMergeStatement( - gCtx, - tableName, - tableNametoUnchangedToastCols[tableName], - getRawTableIdentifier(req.FlowJobName), - batchIDs.SyncBatchID, batchIDs.NormalizeBatchID, - req) + mergeGen := &mergeStmtGenerator{ + rawTableName: getRawTableIdentifier(req.FlowJobName), + dstTableName: tableName, + syncBatchID: batchIDs.SyncBatchID, + normalizeBatchID: batchIDs.NormalizeBatchID, + normalizedTableSchema: c.tableSchemaMapping[tableName], + unchangedToastColumns: tableNametoUnchangedToastCols[tableName], + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: req.SoftDelete, + SoftDeleteColName: req.SoftDeleteColName, + SyncedAtColName: req.SyncedAtColName, + }, + } + mergeStatement, err := mergeGen.generateMergeStmt() + + startTime := time.Now() + c.logger.Info("[merge] merging records...", slog.String("destTable", tableName)) + + result, err := c.database.ExecContext(gCtx, mergeStatement, tableName) + if err != nil { + return fmt.Errorf("failed to merge records into %s (statement: %s): %w", + tableName, mergeStatement, err) + } + + endTime := time.Now() + c.logger.Info(fmt.Sprintf("[merge] merged records into %s, took: %d seconds", + tableName, endTime.Sub(startTime)/time.Second)) if err != nil { c.logger.Error("[merge] error while normalizing records", slog.Any("error", err)) return err } + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected by merge statement for table %s: %w", tableName, err) + } + atomic.AddInt64(&totalRowsAffected, rowsAffected) return nil }) @@ -813,146 +838,6 @@ func getRawTableIdentifier(jobName string) string { return fmt.Sprintf("%s_%s", rawTablePrefix, jobName) } -func (c *SnowflakeConnector) generateAndExecuteMergeStatement( - ctx context.Context, - destinationTableIdentifier string, - unchangedToastColumns []string, - rawTableIdentifier string, - syncBatchID int64, - normalizeBatchID int64, - normalizeReq *model.NormalizeRecordsRequest, -) (int64, error) { - normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier] - parsedDstTable, err := utils.ParseSchemaTable(destinationTableIdentifier) - if err != nil { - return 0, fmt.Errorf("unable to parse destination table '%s'", parsedDstTable) - } - columnNames := utils.TableSchemaColumnNames(normalizedTableSchema) - - flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema)) - err = utils.IterColumnsError(normalizedTableSchema, func(columnName, genericColumnType string) error { - qvKind := qvalue.QValueKind(genericColumnType) - sfType, err := qValueKindToSnowflakeType(qvKind) - if err != nil { - return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err) - } - - targetColumnName := SnowflakeIdentifierNormalize(columnName) - switch qvalue.QValueKind(genericColumnType) { - case qvalue.QValueKindBytes, qvalue.QValueKindBit: - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+ - "AS %s,", toVariantColumnName, columnName, targetColumnName)) - case qvalue.QValueKindGeography: - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - 
fmt.Sprintf("TO_GEOGRAPHY(CAST(%s:\"%s\" AS STRING),true) AS %s,", - toVariantColumnName, columnName, targetColumnName)) - case qvalue.QValueKindGeometry: - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s,", - toVariantColumnName, columnName, targetColumnName)) - case qvalue.QValueKindJSON: - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s,", - toVariantColumnName, columnName, targetColumnName)) - // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types - // case model.ColumnTypeTime: - // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ - // "Microseconds*1000) "+ - // "AS %s,", toVariantColumnName, columnName, columnName)) - default: - if qvKind == qvalue.QValueKindNumeric { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s,", - toVariantColumnName, columnName, sfType, targetColumnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s,", - toVariantColumnName, columnName, sfType, targetColumnName)) - } - } - return nil - }) - if err != nil { - return 0, err - } - flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ""), ",") - - quotedUpperColNames := make([]string, 0, len(columnNames)) - for _, columnName := range columnNames { - quotedUpperColNames = append(quotedUpperColNames, SnowflakeIdentifierNormalize(columnName)) - } - // append synced_at column - quotedUpperColNames = append(quotedUpperColNames, - fmt.Sprintf(`"%s"`, strings.ToUpper(normalizeReq.SyncedAtColName)), - ) - - insertColumnsSQL := strings.TrimSuffix(strings.Join(quotedUpperColNames, ","), ",") - - insertValuesSQLArray := make([]string, 0, len(columnNames)) - for _, columnName := range columnNames { - normalizedColName := SnowflakeIdentifierNormalize(columnName) - insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("SOURCE.%s", normalizedColName)) - } - // fill in synced_at column - insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") - insertValuesSQL := strings.Join(insertValuesSQLArray, ",") - updateStatementsforToastCols := c.generateUpdateStatements(normalizeReq.SyncedAtColName, - normalizeReq.SoftDeleteColName, normalizeReq.SoftDelete, - columnNames, unchangedToastColumns) - - // handling the case when an insert and delete happen in the same batch, with updates in the middle - // with soft-delete, we want the row to be in the destination with SOFT_DELETE true - // the current merge statement doesn't do that, so we add another case to insert the DeleteRecord - if normalizeReq.SoftDelete { - softDeleteInsertColumnsSQL := strings.Join(append(quotedUpperColNames, - normalizeReq.SoftDeleteColName), ",") - softDeleteInsertValuesSQL := insertValuesSQL + ",TRUE" - updateStatementsforToastCols = append(updateStatementsforToastCols, - fmt.Sprintf("WHEN NOT MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) THEN INSERT (%s) VALUES(%s)", - softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) - } - updateStringToastCols := strings.Join(updateStatementsforToastCols, " ") - - normalizedpkeyColsArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - pkeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - for _, pkeyColName := range normalizedTableSchema.PrimaryKeyColumns { - 
normalizedPkeyColName := SnowflakeIdentifierNormalize(pkeyColName) - normalizedpkeyColsArray = append(normalizedpkeyColsArray, normalizedPkeyColName) - pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s", - normalizedPkeyColName, normalizedPkeyColName)) - } - // TARGET. = SOURCE. AND TARGET. = SOURCE. ... - pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ") - - deletePart := "DELETE" - if normalizeReq.SoftDelete { - colName := normalizeReq.SoftDeleteColName - deletePart = fmt.Sprintf("UPDATE SET %s = TRUE", colName) - if normalizeReq.SyncedAtColName != "" { - deletePart = fmt.Sprintf("%s, %s = CURRENT_TIMESTAMP", deletePart, normalizeReq.SyncedAtColName) - } - } - - mergeStatement := fmt.Sprintf(mergeStatementSQL, snowflakeSchemaTableNormalize(parsedDstTable), - toVariantColumnName, rawTableIdentifier, normalizeBatchID, syncBatchID, flattenedCastsSQL, - fmt.Sprintf("(%s)", strings.Join(normalizedpkeyColsArray, ",")), - pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart) - - startTime := time.Now() - c.logger.Info("[merge] merging records...", slog.String("destTable", destinationTableIdentifier)) - - result, err := c.database.ExecContext(ctx, mergeStatement, destinationTableIdentifier) - if err != nil { - return 0, fmt.Errorf("failed to merge records into %s (statement: %s): %w", - destinationTableIdentifier, mergeStatement, err) - } - - endTime := time.Now() - c.logger.Info(fmt.Sprintf("[merge] merged records into %s, took: %d seconds", - destinationTableIdentifier, endTime.Sub(startTime)/time.Second)) - - return result.RowsAffected() -} - func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) { var result pgtype.Bool err := c.database.QueryRowContext(c.ctx, @@ -1039,75 +924,6 @@ func (c *SnowflakeConnector) createPeerDBInternalSchema(createSchemaTx *sql.Tx) return nil } -/* -This function generates UPDATE statements for a MERGE operation based on the provided inputs. - -Inputs: -1. allCols: An array of all column names. -2. unchangedToastCols: An array capturing unique sets of unchanged toast column groups. -3. softDeleteCol: just set to false in the case we see an insert after a soft-deleted column -4. syncedAtCol: set to the CURRENT_TIMESTAMP - -Algorithm: -1. Iterate over each unique set of unchanged toast column groups. -2. For each group, split it into individual column names. -3. Calculate the other columns by finding the set difference between allCols and the unchanged columns. -4. Generate an update statement for the current group by setting the appropriate conditions -and updating the other columns. - - The condition includes checking if the _PEERDB_RECORD_TYPE is not 2 (not a DELETE) and if the - _PEERDB_UNCHANGED_TOAST_COLUMNS match the current group. - - The update sets the other columns to their corresponding values - from the SOURCE table. It doesn't set (make null the Unchanged toast columns. - -5. Append the update statement to the list of generated statements. -6. Repeat steps 1-5 for each unique set of unchanged toast column groups. -7. Return the list of generated update statements. 
-*/ -func (c *SnowflakeConnector) generateUpdateStatements( - syncedAtCol string, softDeleteCol string, softDelete bool, - allCols []string, unchangedToastCols []string, -) []string { - updateStmts := make([]string, 0, len(unchangedToastCols)) - - for _, cols := range unchangedToastCols { - unchangedColsArray := strings.Split(cols, ",") - otherCols := utils.ArrayMinus(allCols, unchangedColsArray) - tmpArray := make([]string, 0, len(otherCols)+2) - for _, colName := range otherCols { - normalizedColName := SnowflakeIdentifierNormalize(colName) - tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", normalizedColName, normalizedColName)) - } - - // set the synced at column to the current timestamp - if syncedAtCol != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, syncedAtCol)) - } - // set soft-deleted to false, tackles insert after soft-delete - if softDeleteCol != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = FALSE`, softDeleteCol)) - } - - ssep := strings.Join(tmpArray, ", ") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - - // generates update statements for the case where updates and deletes happen in the same branch - // the backfill has happened from the pull side already, so treat the DeleteRecord as an update - // and then set soft-delete to true. - if softDelete && (softDeleteCol != "") { - tmpArray = append(tmpArray[:len(tmpArray)-1], fmt.Sprintf(`"%s" = TRUE`, softDeleteCol)) - ssep := strings.Join(tmpArray, ", ") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - } - } - return updateStmts -} - func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) { renameTablesTx, err := c.database.BeginTx(c.ctx, nil) if err != nil { diff --git a/flow/connectors/utils/avro/avro_writer.go b/flow/connectors/utils/avro/avro_writer.go index 1e6f318713..743dcb6419 100644 --- a/flow/connectors/utils/avro/avro_writer.go +++ b/flow/connectors/utils/avro/avro_writer.go @@ -7,6 +7,7 @@ import ( "log" "log/slog" "os" + "sync/atomic" "time" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -19,7 +20,6 @@ import ( "github.com/klauspost/compress/snappy" "github.com/klauspost/compress/zstd" "github.com/linkedin/goavro/v2" - uber_atomic "go.uber.org/atomic" ) type ( @@ -128,8 +128,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ocfWriter *goavro.OCFWriter) ( colNames := schema.GetColumnNames() - var numRows uber_atomic.Uint32 - numRows.Store(0) + numRows := atomic.Uint32{} if p.ctx != nil { shutdown := utils.HeartbeatRoutine(p.ctx, 30*time.Second, func() string { @@ -164,7 +163,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ocfWriter *goavro.OCFWriter) ( return 0, fmt.Errorf("failed to write record to OCF: %w", err) } - numRows.Inc() + numRows.Add(1) } return int(numRows.Load()), nil diff --git a/flow/connectors/utils/identifiers.go b/flow/connectors/utils/identifiers.go index 5318605a93..19867971a9 100644 --- a/flow/connectors/utils/identifiers.go +++ b/flow/connectors/utils/identifiers.go @@ -49,3 +49,10 @@ func IsLower(s string) bool { } return true } + +func RemoveSpacesTabsNewlines(s string) string { + s = strings.ReplaceAll(s, " ", "") + s = strings.ReplaceAll(s, "\t", "") + s = 
strings.ReplaceAll(s, "\n", "") + return s +} diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go index d17e4c6b26..8c45c29872 100644 --- a/flow/e2e/bigquery/peer_flow_bq_test.go +++ b/flow/e2e/bigquery/peer_flow_bq_test.go @@ -702,7 +702,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { c14 INET,c15 INTEGER,c16 INTERVAL,c17 JSON,c18 JSONB,c21 MACADDR,c22 MONEY, c23 NUMERIC,c24 OID,c28 REAL,c29 SMALLINT,c30 SMALLSERIAL,c31 SERIAL,c32 TEXT, c33 TIMESTAMP,c34 TIMESTAMPTZ,c35 TIME, c36 TIMETZ,c37 TSQUERY,c38 TSVECTOR, - c39 TXID_SNAPSHOT,c40 UUID,c41 XML, c42 INT[], c43 FLOAT[], c44 TEXT[], c45 mood); + c39 TXID_SNAPSHOT,c40 UUID,c41 XML, c42 INT[], c43 FLOAT[], c44 TEXT[], c45 mood, c46 HSTORE); `, srcTableName)) require.NoError(s.t, err) @@ -737,8 +737,9 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { txid_current_snapshot(), '66073c38-b8df-4bdb-bbca-1c97596b8940'::uuid,xmlcomment('hello'), ARRAY[10299301,2579827], - ARRAY[0.0003, 8902.0092, 'NaN'], - ARRAY['hello','bye'],'happy'; + ARRAY[0.0003, 8902.0092], + ARRAY['hello','bye'],'happy', + 'key1=>value1, key2=>NULL'::hstore `, srcTableName)) e2e.EnvNoError(s.t, env, err) }() @@ -756,7 +757,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { "c41", "c1", "c2", "c3", "c4", "c6", "c39", "c40", "id", "c9", "c11", "c12", "c13", "c14", "c15", "c16", "c17", "c18", "c21", "c22", "c23", "c24", "c28", "c29", "c30", "c31", "c33", "c34", "c35", "c36", - "c37", "c38", "c7", "c8", "c32", "c42", "c43", "c44", "c45", + "c37", "c38", "c7", "c8", "c32", "c42", "c43", "c44", "c45", "c46", }) if err != nil { s.t.Log(err) @@ -767,6 +768,14 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Types_BQ() { // check if JSON on bigquery side is a good JSON err = s.checkJSONValue(dstTableName, "c17", "sai", "-8.021390374331551") require.NoError(s.t, err) + + // check if HSTORE on bigquery side is a good JSON + err = s.checkJSONValue(dstTableName, "c46", "key1", "\"value1\"") + require.NoError(s.t, err) + err = s.checkJSONValue(dstTableName, "c46", "key2", "null") + require.NoError(s.t, err) + + env.AssertExpectations(s.t) } func (s PeerFlowE2ETestSuiteBQ) Test_NaN_Doubles_BQ() { diff --git a/flow/go.mod b/flow/go.mod index 66e9991c72..dd26c93819 100644 --- a/flow/go.mod +++ b/flow/go.mod @@ -33,7 +33,6 @@ require ( github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a go.temporal.io/api v1.26.0 go.temporal.io/sdk v1.25.1 - go.uber.org/atomic v1.11.0 go.uber.org/automaxprocs v1.5.3 golang.org/x/sync v0.5.0 google.golang.org/api v0.154.0 @@ -70,6 +69,7 @@ require ( go.opentelemetry.io/otel v1.21.0 // indirect go.opentelemetry.io/otel/metric v1.21.0 // indirect go.opentelemetry.io/otel/trace v1.21.0 // indirect + go.uber.org/atomic v1.11.0 // indirect ) require ( diff --git a/flow/hstore/hstore.go b/flow/hstore/hstore.go new file mode 100644 index 0000000000..0253fef2b1 --- /dev/null +++ b/flow/hstore/hstore.go @@ -0,0 +1,233 @@ +/* +This is in reference to PostgreSQL's hstore: +https://github.com/postgres/postgres/blob/bea18b1c949145ba2ca79d4765dba3cc9494a480/contrib/hstore/hstore_io.c + +This package is an implementation based on the above code. +It's simplified to only parse the subset which `hstore_out` outputs. 
+*/ +package hstore_util + +import ( + "encoding/json" + "errors" + "fmt" + "strings" +) + +type text struct { + String string + Valid bool +} + +type hstore map[string]*string + +type hstoreParser struct { + str string + pos int + nextBackslash int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + nextBackslash: strings.IndexByte(in, '\\'), + } +} + +func (p *hstoreParser) atEnd() bool { + return p.pos >= len(p.str) +} + +// consume returns the next byte of the string, or end if the string is done. +func (p *hstoreParser) consume() (b byte, end bool) { + if p.pos >= len(p.str) { + return 0, true + } + b = p.str[p.pos] + p.pos++ + return b, false +} + +func unexpectedByteErr(actualB byte, expectedB byte) error { + return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) +} + +// consumeExpectedByte consumes expectedB from the string, or returns an error. +func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { + nextB, end := p.consume() + if end { + return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB) + } + if nextB != expectedB { + return unexpectedByteErr(nextB, expectedB) + } + return nil +} + +func (p *hstoreParser) consumeExpected2(one byte, two byte) error { + if p.pos+2 > len(p.str) { + return errors.New("unexpected end of string") + } + if p.str[p.pos] != one { + return unexpectedByteErr(p.str[p.pos], one) + } + if p.str[p.pos+1] != two { + return unexpectedByteErr(p.str[p.pos+1], two) + } + p.pos += 2 + return nil +} + +var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`) + +// consumeDoubleQuoted consumes a double-quoted string from p. The double quote must have been +// parsed already. +func (p *hstoreParser) consumeDoubleQuoted() (string, error) { + // fast path: assume most keys/values do not contain escapes + nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"') + if nextDoubleQuote == -1 { + return "", errEOSInQuoted + } + nextDoubleQuote += p.pos + if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote { + s := p.str[p.pos:nextDoubleQuote] + p.pos = nextDoubleQuote + 1 + return s, nil + } + + s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash) + p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\') + if p.nextBackslash != -1 { + p.nextBackslash += p.pos + } + return s, err +} + +// consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting +// at p.pos, and with the first backslash at firstBackslash. This copies the string so it can be +// garbage collected separately. 
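+// For example (mirroring TestHStoreEscapedQuotes and TestHStoreEscapedBackslashes below), the quoted
+// segment a\"b unescapes to a"b and a\\b to a\b; any other escape sequence is rejected with an error.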
+func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) {
+ // copy the prefix that does not contain backslashes
+ var builder strings.Builder
+ builder.WriteString(p.str[p.pos:firstBackslash])
+
+ // skip to the backslash
+ p.pos = firstBackslash
+
+ // copy bytes until the end, unescaping backslashes
+ for {
+ nextB, end := p.consume()
+ if end {
+ return "", errEOSInQuoted
+ } else if nextB == '"' {
+ break
+ } else if nextB == '\\' {
+ // escape: skip the backslash and copy the char
+ nextB, end = p.consume()
+ if end {
+ return "", errEOSInQuoted
+ }
+ if !(nextB == '\\' || nextB == '"') {
+ return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB)
+ }
+ builder.WriteByte(nextB)
+ } else {
+ // normal byte: copy it
+ builder.WriteByte(nextB)
+ }
+ }
+ return builder.String(), nil
+}
+
+// consumePairSeparator consumes the Hstore pair separator ", " or returns an error.
+func (p *hstoreParser) consumePairSeparator() error {
+ return p.consumeExpected2(',', ' ')
+}
+
+// consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error.
+func (p *hstoreParser) consumeKVSeparator() error {
+ return p.consumeExpected2('=', '>')
+}
+
+// consumeDoubleQuotedOrNull consumes a double-quoted string or the unquoted literal NULL, or returns an error.
+func (p *hstoreParser) consumeDoubleQuotedOrNull() (text, error) {
+ // peek at the next byte
+ if p.atEnd() {
+ return text{}, errors.New("found end instead of value")
+ }
+ next := p.str[p.pos]
+ if next == 'N' {
+ // must be the exact string NULL: use consumeExpected2 twice
+ err := p.consumeExpected2('N', 'U')
+ if err != nil {
+ return text{}, err
+ }
+ err = p.consumeExpected2('L', 'L')
+ if err != nil {
+ return text{}, err
+ }
+ return text{String: "", Valid: false}, nil
+ } else if next != '"' {
+ return text{}, unexpectedByteErr(next, '"')
+ }
+
+ // skip the double quote
+ p.pos += 1
+ s, err := p.consumeDoubleQuoted()
+ if err != nil {
+ return text{}, err
+ }
+ return text{String: s, Valid: true}, nil
+}
+
+func ParseHstore(s string) (string, error) {
+ p := newHSP(s)
+
+ // This is an over-estimate of the number of key/value pairs.
+ numPairsEstimate := strings.Count(s, ">")
+ result := make(hstore, numPairsEstimate)
+ first := true
+ for !p.atEnd() {
+ if !first {
+ err := p.consumePairSeparator()
+ if err != nil {
+ return "", err
+ }
+ } else {
+ first = false
+ }
+
+ err := p.consumeExpectedByte('"')
+ if err != nil {
+ return "", err
+ }
+
+ key, err := p.consumeDoubleQuoted()
+ if err != nil {
+ return "", err
+ }
+
+ err = p.consumeKVSeparator()
+ if err != nil {
+ return "", err
+ }
+
+ value, err := p.consumeDoubleQuotedOrNull()
+ if err != nil {
+ return "", err
+ }
+ if value.Valid {
+ result[key] = &value.String
+ } else {
+ result[key] = nil
+ }
+ }
+
+ jsonBytes, err := json.Marshal(result)
+ if err != nil {
+ return "", err
+ }
+
+ return string(jsonBytes), nil
+}
diff --git a/flow/hstore/hstore_test.go b/flow/hstore/hstore_test.go
new file mode 100644
index 0000000000..721a6bf523
--- /dev/null
+++ b/flow/hstore/hstore_test.go
@@ -0,0 +1,103 @@
+package hstore_util
+
+import (
+ "testing"
+)
+
+func TestHStoreHappy(t *testing.T) {
+ testCase := `"a"=>"b", "c"=>"d"`
+ expected := `{"a":"b","c":"d"}`
+
+ result, err := ParseHstore(testCase)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ if result != expected {
+ t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} + +func TestHStoreEscapedQuotes(t *testing.T) { + testCase := `"a\"b"=>"c\"d"` + expected := `{"a\"b":"c\"d"}` + + result, err := ParseHstore(testCase) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if result != expected { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestHStoreEscapedBackslashes(t *testing.T) { + testCase := `"a\\b"=>"c\\d"` + expected := `{"a\\b":"c\\d"}` + + result, err := ParseHstore(testCase) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if result != expected { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestHStoreNullCase(t *testing.T) { + testCase := `"a"=>NULL` + expected := `{"a":null}` + + result, err := ParseHstore(testCase) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if result != expected { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestHStoreDisguisedSeparator(t *testing.T) { + testCase := `"=>"=>"a=>b"` + expected := `{"=\u003e":"a=\u003eb"}` + + result, err := ParseHstore(testCase) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if result != expected { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestHStoreEmpty(t *testing.T) { + testCase := `""=>" "` + expected := `{"":" "}` + + result, err := ParseHstore(testCase) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if result != expected { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestHStoreDuplicate(t *testing.T) { + testCase := `"a"=>"b", "a"=>"c"` + expected := `{"a":"c"}` + + result, err := ParseHstore(testCase) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if result != expected { + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} diff --git a/flow/model/model.go b/flow/model/model.go index 15846b6d74..0eb82dc8a9 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -11,6 +11,7 @@ import ( "time" "github.com/PeerDB-io/peer-flow/generated/protos" + hstore_util "github.com/PeerDB-io/peer-flow/hstore" "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/peerdbenv" ) @@ -156,6 +157,23 @@ func (r *RecordItems) toMap() (map[string]interface{}, error) { } else { jsonStruct[col] = strVal } + case qvalue.QValueKindHStore: + hstoreVal, ok := v.Value.(string) + if !ok { + return nil, fmt.Errorf("expected string value for hstore column %s for value %T", col, v.Value) + } + + jsonVal, err := hstore_util.ParseHstore(hstoreVal) + if err != nil { + return nil, fmt.Errorf("unable to convert hstore column %s to json for value %T", col, v.Value) + } + + if len(jsonVal) > 15*1024*1024 { + jsonStruct[col] = "" + } else { + jsonStruct[col] = jsonVal + } + case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindDate, qvalue.QValueKindTime, qvalue.QValueKindTimeTZ: jsonStruct[col], err = v.GoTimeConvert() diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 4f9cbe2e47..9fafb8a18c 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -6,6 +6,7 @@ import ( "math/big" "time" + hstore_util "github.com/PeerDB-io/peer-flow/hstore" "github.com/google/uuid" "github.com/linkedin/goavro/v2" ) @@ -156,7 +157,7 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { case QValueKindJSON: return c.processJSON() case QValueKindHStore: - return nil, fmt.Errorf("QValueKindHStore not supported") + return c.processHStore() case QValueKindArrayFloat32: return c.processArrayFloat32() case QValueKindArrayFloat64: @@ -271,6 +272,38 @@ func (c *QValueAvroConverter) processJSON() (interface{}, error) { return jsonString, nil } +func (c *QValueAvroConverter) processHStore() (interface{}, error) { + if c.Value.Value == nil && c.Nullable { + return nil, nil + } + + hstoreString, ok := c.Value.Value.(string) + if !ok { + return nil, fmt.Errorf("invalid HSTORE value %v", c.Value.Value) + } + + jsonString, err := hstore_util.ParseHstore(hstoreString) + if err != nil { + return "", err + } + + if c.Nullable { + if c.TargetDWH == QDWHTypeSnowflake && len(jsonString) > 15*1024*1024 { + slog.Warn("Truncating HStore equivalent JSON value > 15MB for Snowflake!") + slog.Warn("Check this issue for details: https://github.com/PeerDB-io/peerdb/issues/309") + return goavro.Union("string", ""), nil + } + return goavro.Union("string", jsonString), nil + } + + if c.TargetDWH == QDWHTypeSnowflake && len(jsonString) > 15*1024*1024 { + slog.Warn("Truncating HStore equivalent JSON value > 15MB for Snowflake!") + slog.Warn("Check this issue for details: https://github.com/PeerDB-io/peerdb/issues/309") + return "", nil + } + return jsonString, nil +} + func (c *QValueAvroConverter) processUUID() (interface{}, error) { if c.Value.Value == nil { return nil, nil diff --git a/flow/workflows/qrep_flow.go b/flow/workflows/qrep_flow.go index 0b78a6dc43..b44e0df207 100644 --- a/flow/workflows/qrep_flow.go +++ b/flow/workflows/qrep_flow.go @@ -376,26 +376,7 @@ func (q *QRepFlowExecution) receiveAndHandleSignalAsync(ctx workflow.Context) { } } -func QRepFlowWorkflow( - ctx workflow.Context, - config *protos.QRepConfig, - state *protos.QRepFlowState, -) error { - // The structure of this workflow is 
as follows: - // 1. Start the loop to continuously run the replication flow. - // 2. In the loop, query the source database to get the partitions to replicate. - // 3. For each partition, start a new workflow to replicate the partition. - // 4. Wait for all the workflows to complete. - // 5. Sleep for a while and repeat the loop. - - ctx = workflow.WithValue(ctx, shared.FlowNameKey, config.FlowJobName) - logger := workflow.GetLogger(ctx) - - maxParallelWorkers := 16 - if config.MaxParallelWorkers > 0 { - maxParallelWorkers = int(config.MaxParallelWorkers) - } - +func setWorkflowQueries(ctx workflow.Context, state *protos.QRepFlowState) error { // Support a Query for the current state of the qrep flow. err := workflow.SetQueryHandler(ctx, shared.QRepFlowStateQuery, func() (*protos.QRepFlowState, error) { return state, nil @@ -404,7 +385,7 @@ func QRepFlowWorkflow( return fmt.Errorf("failed to set `%s` query handler: %w", shared.QRepFlowStateQuery, err) } - // Support a Query for the current status of the arep flow. + // Support a Query for the current status of the qrep flow. err = workflow.SetQueryHandler(ctx, shared.FlowStatusQuery, func() (*protos.FlowStatus, error) { return &state.CurrentFlowState, nil }) @@ -420,6 +401,33 @@ func QRepFlowWorkflow( if err != nil { return fmt.Errorf("failed to register query handler: %w", err) } + return nil +} + +func QRepFlowWorkflow( + ctx workflow.Context, + config *protos.QRepConfig, + state *protos.QRepFlowState, +) error { + // The structure of this workflow is as follows: + // 1. Start the loop to continuously run the replication flow. + // 2. In the loop, query the source database to get the partitions to replicate. + // 3. For each partition, start a new workflow to replicate the partition. + // 4. Wait for all the workflows to complete. + // 5. Sleep for a while and repeat the loop. + + ctx = workflow.WithValue(ctx, shared.FlowNameKey, config.FlowJobName) + logger := workflow.GetLogger(ctx) + + maxParallelWorkers := 16 + if config.MaxParallelWorkers > 0 { + maxParallelWorkers = int(config.MaxParallelWorkers) + } + + err := setWorkflowQueries(ctx, state) + if err != nil { + return err + } // get qrep run uuid via side-effect runUUIDSideEffect := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} { diff --git a/flow/workflows/xmin_flow.go b/flow/workflows/xmin_flow.go index 20fb7f701e..1394d17353 100644 --- a/flow/workflows/xmin_flow.go +++ b/flow/workflows/xmin_flow.go @@ -3,7 +3,6 @@ package peerflow import ( "fmt" - "strings" "time" "github.com/PeerDB-io/peer-flow/generated/protos" @@ -36,172 +35,31 @@ func NewXminFlowExecution(ctx workflow.Context, config *protos.QRepConfig, runUU } } -// SetupMetadataTables creates the metadata tables for query based replication. 
-func (q *XminFlowExecution) SetupMetadataTables(ctx workflow.Context) error { - q.logger.Info("setting up metadata tables for xmin flow - ", q.config.FlowJobName) - - ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ - StartToCloseTimeout: 5 * time.Minute, - }) - - if err := workflow.ExecuteActivity(ctx, flowable.SetupQRepMetadataTables, q.config).Get(ctx, nil); err != nil { - return fmt.Errorf("failed to setup metadata tables: %w", err) - } - - q.logger.Info("metadata tables setup for xmin flow - ", q.config.FlowJobName) - return nil -} - -func (q *XminFlowExecution) SetupWatermarkTableOnDestination(ctx workflow.Context) error { - if q.config.SetupWatermarkTableOnDestination { - q.logger.Info("setting up watermark table on destination for xmin flow: ", q.config.FlowJobName) - - ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ - StartToCloseTimeout: 5 * time.Minute, - }) - - tableSchemaInput := &protos.GetTableSchemaBatchInput{ - PeerConnectionConfig: q.config.SourcePeer, - TableIdentifiers: []string{q.config.WatermarkTable}, - FlowName: q.config.FlowJobName, - SkipPkeyAndReplicaCheck: true, - } - - future := workflow.ExecuteActivity(ctx, flowable.GetTableSchema, tableSchemaInput) - - var tblSchemaOutput *protos.GetTableSchemaBatchOutput - if err := future.Get(ctx, &tblSchemaOutput); err != nil { - q.logger.Error("failed to fetch schema for watermark table: ", err) - return fmt.Errorf("failed to fetch schema for watermark table %s: %w", q.config.WatermarkTable, err) - } - - // now setup the normalized tables on the destination peer - setupConfig := &protos.SetupNormalizedTableBatchInput{ - PeerConnectionConfig: q.config.DestinationPeer, - TableNameSchemaMapping: map[string]*protos.TableSchema{ - q.config.DestinationTableIdentifier: tblSchemaOutput.TableNameSchemaMapping[q.config.WatermarkTable], - }, - FlowName: q.config.FlowJobName, - } - - future = workflow.ExecuteActivity(ctx, flowable.CreateNormalizedTable, setupConfig) - var createNormalizedTablesOutput *protos.SetupNormalizedTableBatchOutput - if err := future.Get(ctx, &createNormalizedTablesOutput); err != nil { - q.logger.Error("failed to create watermark table: ", err) - return fmt.Errorf("failed to create watermark table: %w", err) - } - q.logger.Info("finished setting up watermark table for xmin flow: ", q.config.FlowJobName) - } - return nil -} - -func (q *XminFlowExecution) handleTableCreationForResync(ctx workflow.Context, state *protos.QRepFlowState) error { - if state.NeedsResync && q.config.DstTableFullResync { - renamedTableIdentifier := fmt.Sprintf("%s_peerdb_resync", q.config.DestinationTableIdentifier) - createTablesFromExistingCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ - StartToCloseTimeout: 10 * time.Minute, - HeartbeatTimeout: 2 * time.Minute, - }) - createTablesFromExistingFuture := workflow.ExecuteActivity( - createTablesFromExistingCtx, flowable.CreateTablesFromExisting, &protos.CreateTablesFromExistingInput{ - FlowJobName: q.config.FlowJobName, - Peer: q.config.DestinationPeer, - NewToExistingTableMapping: map[string]string{ - renamedTableIdentifier: q.config.DestinationTableIdentifier, - }, - }) - if err := createTablesFromExistingFuture.Get(createTablesFromExistingCtx, nil); err != nil { - return fmt.Errorf("failed to create table for mirror resync: %w", err) - } - q.config.DestinationTableIdentifier = renamedTableIdentifier - } - return nil -} - -func (q *XminFlowExecution) handleTableRenameForResync(ctx workflow.Context, state *protos.QRepFlowState) 
error { - if state.NeedsResync && q.config.DstTableFullResync { - oldTableIdentifier := strings.TrimSuffix(q.config.DestinationTableIdentifier, "_peerdb_resync") - renameOpts := &protos.RenameTablesInput{} - renameOpts.FlowJobName = q.config.FlowJobName - renameOpts.Peer = q.config.DestinationPeer - renameOpts.RenameTableOptions = []*protos.RenameTableOption{ - { - CurrentName: q.config.DestinationTableIdentifier, - NewName: oldTableIdentifier, - }, - } - - renameTablesCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ - StartToCloseTimeout: 30 * time.Minute, - HeartbeatTimeout: 5 * time.Minute, - }) - renameTablesFuture := workflow.ExecuteActivity(renameTablesCtx, flowable.RenameTables, renameOpts) - if err := renameTablesFuture.Get(renameTablesCtx, nil); err != nil { - return fmt.Errorf("failed to execute rename tables activity: %w", err) - } - q.config.DestinationTableIdentifier = oldTableIdentifier - } - state.NeedsResync = false - return nil -} - -func (q *XminFlowExecution) receiveAndHandleSignalAsync(ctx workflow.Context) { - signalChan := workflow.GetSignalChannel(ctx, shared.CDCFlowSignalName) - - var signalVal shared.CDCFlowSignal - ok := signalChan.ReceiveAsync(&signalVal) - if ok { - q.activeSignal = shared.FlowSignalHandler(q.activeSignal, signalVal, q.logger) - } -} - -// For some targets we need to consolidate all the partitions from stages before -// we proceed to next batch. -func (q *XminFlowExecution) consolidatePartitions(ctx workflow.Context) error { - q.logger.Info("consolidating partitions") - - // only an operation for Snowflake currently. - ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ - StartToCloseTimeout: 24 * time.Hour, - HeartbeatTimeout: 10 * time.Minute, - }) - - if err := workflow.ExecuteActivity(ctx, flowable.ConsolidateQRepPartitions, q.config, - q.runUUID).Get(ctx, nil); err != nil { - return fmt.Errorf("failed to consolidate partitions: %w", err) - } - - q.logger.Info("partitions consolidated") - - // clean up qrep flow as well - if err := workflow.ExecuteActivity(ctx, flowable.CleanupQRepFlow, q.config).Get(ctx, nil); err != nil { - return fmt.Errorf("failed to cleanup qrep flow: %w", err) - } - - q.logger.Info("xmin flow cleaned up") - - return nil -} - func XminFlowWorkflow( ctx workflow.Context, config *protos.QRepConfig, state *protos.QRepFlowState, ) error { ctx = workflow.WithValue(ctx, shared.FlowNameKey, config.FlowJobName) + // Support a Query for the current state of the xmin flow. 
+ err := setWorkflowQueries(ctx, state) + if err != nil { + return err + } + // get xmin run uuid via side-effect runUUIDSideEffect := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} { return uuid.New().String() }) - var runUUID string if err := runUUIDSideEffect.Get(&runUUID); err != nil { return fmt.Errorf("failed to get run uuid: %w", err) } - q := NewXminFlowExecution(ctx, config, runUUID) + x := NewXminFlowExecution(ctx, config, runUUID) - err := q.SetupWatermarkTableOnDestination(ctx) + q := NewQRepFlowExecution(ctx, config, runUUID) + err = q.SetupWatermarkTableOnDestination(ctx) if err != nil { return fmt.Errorf("failed to setup watermark table: %w", err) } @@ -210,7 +68,7 @@ func XminFlowWorkflow( if err != nil { return fmt.Errorf("failed to setup metadata tables: %w", err) } - q.logger.Info("metadata tables setup for peer flow - ", config.FlowJobName) + x.logger.Info("metadata tables setup for peer flow - ", config.FlowJobName) err = q.handleTableCreationForResync(ctx, state) if err != nil { @@ -225,9 +83,9 @@ func XminFlowWorkflow( err = workflow.ExecuteActivity( replicateXminPartitionCtx, flowable.ReplicateXminPartition, - q.config, + x.config, state.LastPartition, - q.runUUID, + x.runUUID, ).Get(ctx, &lastPartition) if err != nil { return fmt.Errorf("xmin replication failed: %w", err) @@ -238,7 +96,7 @@ func XminFlowWorkflow( } if config.InitialCopyOnly { - q.logger.Info("initial copy completed for peer flow - ", config.FlowJobName) + x.logger.Info("initial copy completed for peer flow - ", config.FlowJobName) return nil } @@ -248,7 +106,7 @@ func XminFlowWorkflow( } state.LastPartition = &protos.QRepPartition{ - PartitionId: q.runUUID, + PartitionId: x.runUUID, Range: &protos.PartitionRange{Range: &protos.PartitionRange_IntRange{IntRange: &protos.IntPartitionRange{Start: lastPartition}}}, } @@ -259,22 +117,22 @@ func XminFlowWorkflow( // here, we handle signals after the end of the flow because a new workflow does not inherit the signals // and the chance of missing a signal is much higher if the check is before the time consuming parts run q.receiveAndHandleSignalAsync(ctx) - if q.activeSignal == shared.PauseSignal { + if x.activeSignal == shared.PauseSignal { startTime := time.Now() signalChan := workflow.GetSignalChannel(ctx, shared.CDCFlowSignalName) var signalVal shared.CDCFlowSignal - for q.activeSignal == shared.PauseSignal { - q.logger.Info("mirror has been paused for ", time.Since(startTime)) + for x.activeSignal == shared.PauseSignal { + x.logger.Info("mirror has been paused for ", time.Since(startTime)) // only place we block on receive, so signal processing is immediate ok, _ := signalChan.ReceiveWithTimeout(ctx, 1*time.Minute, &signalVal) if ok { - q.activeSignal = shared.FlowSignalHandler(q.activeSignal, signalVal, q.logger) + x.activeSignal = shared.FlowSignalHandler(x.activeSignal, signalVal, x.logger) } } } - if q.activeSignal == shared.ShutdownSignal { - q.logger.Info("terminating workflow - ", config.FlowJobName) + if x.activeSignal == shared.ShutdownSignal { + x.logger.Info("terminating workflow - ", config.FlowJobName) return nil } diff --git a/nexus/analyzer/src/lib.rs b/nexus/analyzer/src/lib.rs index 6e469eddcf..bbbbf531d8 100644 --- a/nexus/analyzer/src/lib.rs +++ b/nexus/analyzer/src/lib.rs @@ -10,8 +10,8 @@ use anyhow::Context; use pt::{ flow_model::{FlowJob, FlowJobTableMapping, QRepFlowJob}, peerdb_peers::{ - peer::Config, BigqueryConfig, DbType, EventHubConfig, MongoConfig, Peer, PostgresConfig, - S3Config, SnowflakeConfig, 
SqlServerConfig,ClickhouseConfig + peer::Config, BigqueryConfig, ClickhouseConfig, DbType, EventHubConfig, MongoConfig, Peer, + PostgresConfig, S3Config, SnowflakeConfig, SqlServerConfig, }, }; use qrep::process_options; @@ -798,7 +798,7 @@ fn parse_db_options( .get("s3_integration") .map(|s| s.to_string()) .unwrap_or_default(); - + let clickhouse_config = ClickhouseConfig { host: opts.get("host").context("no host specified")?.to_string(), port: opts @@ -822,7 +822,7 @@ fn parse_db_options( }; let config = Config::ClickhouseConfig(clickhouse_config); Some(config) - } + } }; Ok(config) diff --git a/nexus/catalog/src/lib.rs b/nexus/catalog/src/lib.rs index bb36277043..9dedbb7a33 100644 --- a/nexus/catalog/src/lib.rs +++ b/nexus/catalog/src/lib.rs @@ -143,7 +143,7 @@ impl Catalog { let config_len = clickhouse_config.encoded_len(); buf.reserve(config_len); clickhouse_config.encode(&mut buf)?; - } + } }; buf @@ -340,11 +340,14 @@ impl Catalog { Ok(Some(Config::EventhubGroupConfig(eventhub_group_config))) } Some(DbType::Clickhouse) => { - let err = format!("unable to decode {} options for peer {}", "clickhouse", name); + let err = format!( + "unable to decode {} options for peer {}", + "clickhouse", name + ); let clickhouse_config = pt::peerdb_peers::ClickhouseConfig::decode(options).context(err)?; Ok(Some(Config::ClickhouseConfig(clickhouse_config))) - } + } None => Ok(None), } } diff --git a/nexus/parser/src/lib.rs b/nexus/parser/src/lib.rs index f99dbe8751..4a305c7899 100644 --- a/nexus/parser/src/lib.rs +++ b/nexus/parser/src/lib.rs @@ -11,13 +11,12 @@ use pgwire::{ error::{ErrorInfo, PgWireError, PgWireResult}, }; use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; -use tokio::sync::Mutex; const DIALECT: PostgreSqlDialect = PostgreSqlDialect {}; #[derive(Clone)] pub struct NexusQueryParser { - catalog: Arc>, + catalog: Arc, } #[derive(Debug, Clone)] @@ -93,13 +92,12 @@ pub struct NexusParsedStatement { } impl NexusQueryParser { - pub fn new(catalog: Arc>) -> Self { + pub fn new(catalog: Arc) -> Self { Self { catalog } } pub async fn get_peers_bridge(&self) -> PgWireResult> { - let catalog = self.catalog.lock().await; - let peers = catalog.get_peers().await; + let peers = self.catalog.get_peers().await; peers.map_err(|e| { PgWireError::UserError(Box::new(ErrorInfo::new( diff --git a/nexus/peer-bigquery/src/cursor.rs b/nexus/peer-bigquery/src/cursor.rs index 23812a382a..52558600ef 100644 --- a/nexus/peer-bigquery/src/cursor.rs +++ b/nexus/peer-bigquery/src/cursor.rs @@ -1,8 +1,7 @@ use dashmap::DashMap; -use tokio::sync::Mutex; use futures::StreamExt; -use peer_cursor::{QueryExecutor, QueryOutput, Records, SchemaRef, SendableStream}; +use peer_cursor::{QueryExecutor, QueryOutput, Records, Schema, SendableStream}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use sqlparser::ast::Statement; @@ -10,8 +9,8 @@ use crate::BigQueryQueryExecutor; pub struct BigQueryCursor { position: usize, - stream: Mutex, - schema: SchemaRef, + stream: SendableStream, + schema: Schema, } pub struct BigQueryCursorManager { @@ -42,7 +41,7 @@ impl BigQueryCursorManager { // Create a new cursor let cursor = BigQueryCursor { position: 0, - stream: Mutex::new(stream), + stream, schema, }; @@ -75,9 +74,8 @@ impl BigQueryCursorManager { let prev_end = cursor.position; let mut cursor_position = cursor.position; { - let mut stream = cursor.stream.lock().await; while cursor_position - prev_end < count { - match stream.next().await { + match cursor.stream.next().await { Some(Ok(record)) => 
{ records.push(record); cursor_position += 1; diff --git a/nexus/peer-bigquery/src/lib.rs b/nexus/peer-bigquery/src/lib.rs index 29d58fb24f..e0f9fa99f3 100644 --- a/nexus/peer-bigquery/src/lib.rs +++ b/nexus/peer-bigquery/src/lib.rs @@ -7,7 +7,7 @@ use gcp_bigquery_client::{ Client, }; use peer_connections::PeerConnectionTracker; -use peer_cursor::{CursorModification, QueryExecutor, QueryOutput, SchemaRef}; +use peer_cursor::{CursorModification, QueryExecutor, QueryOutput, Schema}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pt::peerdb_peers::BigqueryConfig; use sqlparser::ast::{CloseCursor, Expr, FetchDirection, Statement, Value}; @@ -200,7 +200,7 @@ impl QueryExecutor for BigQueryQueryExecutor { } // describe the output of the query - async fn describe(&self, stmt: &Statement) -> PgWireResult> { + async fn describe(&self, stmt: &Statement) -> PgWireResult> { // print the statement tracing::info!("[bigquery] describe: {}", stmt); // only support SELECT statements diff --git a/nexus/peer-bigquery/src/stream.rs b/nexus/peer-bigquery/src/stream.rs index a831f6818f..b5738d914f 100644 --- a/nexus/peer-bigquery/src/stream.rs +++ b/nexus/peer-bigquery/src/stream.rs @@ -1,6 +1,7 @@ use std::{ pin::Pin, str::FromStr, + sync::Arc, task::{Context, Poll}, }; @@ -9,7 +10,7 @@ use futures::Stream; use gcp_bigquery_client::model::{ field_type::FieldType, query_response::ResultSet, table_field_schema::TableFieldSchema, }; -use peer_cursor::{Record, RecordStream, Schema, SchemaRef}; +use peer_cursor::{Record, RecordStream, Schema}; use pgwire::{ api::{ results::{FieldFormat, FieldInfo}, @@ -22,7 +23,7 @@ use value::Value; #[derive(Debug)] pub struct BqSchema { - schema: SchemaRef, + schema: Schema, fields: Vec, } @@ -68,15 +69,15 @@ impl BqSchema { .as_ref() .expect("Schema fields are not present"); - let schema = SchemaRef::new(Schema { - fields: fields + let schema = Arc::new( + fields .iter() .map(|field| { let datatype = convert_field_type(&field.r#type); FieldInfo::new(field.name.clone(), None, None, datatype, FieldFormat::Text) }) .collect(), - }); + ); Self { schema, @@ -84,7 +85,7 @@ impl BqSchema { } } - pub fn schema(&self) -> SchemaRef { + pub fn schema(&self) -> Schema { self.schema.clone() } } @@ -192,7 +193,7 @@ impl Stream for BqRecordStream { } impl RecordStream for BqRecordStream { - fn schema(&self) -> SchemaRef { + fn schema(&self) -> Schema { self.schema.schema() } } diff --git a/nexus/peer-cursor/src/lib.rs b/nexus/peer-cursor/src/lib.rs index 7d2525a7df..5c3b297080 100644 --- a/nexus/peer-cursor/src/lib.rs +++ b/nexus/peer-cursor/src/lib.rs @@ -7,27 +7,22 @@ use value::Value; pub mod util; -#[derive(Debug, Clone)] -pub struct Schema { - pub fields: Vec, -} - -pub type SchemaRef = Arc; +pub type Schema = Arc>; pub struct Record { pub values: Vec, - pub schema: SchemaRef, + pub schema: Schema, } pub trait RecordStream: Stream> { - fn schema(&self) -> SchemaRef; + fn schema(&self) -> Schema; } -pub type SendableStream = Pin>; +pub type SendableStream = Pin>; pub struct Records { pub records: Vec, - pub schema: SchemaRef, + pub schema: Schema, } #[derive(Debug, Clone)] @@ -50,7 +45,7 @@ pub enum QueryOutput { pub trait QueryExecutor: Send + Sync { async fn execute(&self, stmt: &Statement) -> PgWireResult; - async fn describe(&self, stmt: &Statement) -> PgWireResult>; + async fn describe(&self, stmt: &Statement) -> PgWireResult>; async fn is_connection_valid(&self) -> anyhow::Result; } diff --git a/nexus/peer-cursor/src/util.rs b/nexus/peer-cursor/src/util.rs index 
e9b9d55b00..568a81da5b 100644 --- a/nexus/peer-cursor/src/util.rs +++ b/nexus/peer-cursor/src/util.rs @@ -1,13 +1,11 @@ -use std::sync::Arc; - use futures::{stream, StreamExt}; use pgwire::{ - api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response}, + api::results::{DataRowEncoder, QueryResponse, Response}, error::{PgWireError, PgWireResult}, }; use value::Value; -use crate::{Records, SchemaRef, SendableStream}; +use crate::{Records, Schema, SendableStream}; fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> { match value { @@ -58,11 +56,10 @@ fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> } pub fn sendable_stream_to_query_response<'a>( - schema: SchemaRef, + schema: Schema, record_stream: SendableStream, ) -> PgWireResult> { - let pg_schema: Arc> = Arc::new(schema.fields.clone()); - let schema_copy = pg_schema.clone(); + let schema_copy = schema.clone(); let data_row_stream = record_stream .map(move |record_result| { @@ -76,15 +73,11 @@ pub fn sendable_stream_to_query_response<'a>( }) .boxed(); - Ok(Response::Query(QueryResponse::new( - pg_schema, - data_row_stream, - ))) + Ok(Response::Query(QueryResponse::new(schema, data_row_stream))) } pub fn records_to_query_response<'a>(records: Records) -> PgWireResult> { - let pg_schema: Arc> = Arc::new(records.schema.fields.clone()); - let schema_copy = pg_schema.clone(); + let schema_copy = records.schema.clone(); let data_row_stream = stream::iter(records.records) .map(move |record| { @@ -97,7 +90,7 @@ pub fn records_to_query_response<'a>(records: Records) -> PgWireResult anyhow::Result { + pub async fn schema_from_query(&self, query: &str) -> anyhow::Result { let prepared = self.client.prepare_typed(query, &[]).await?; let fields: Vec = prepared @@ -42,7 +42,7 @@ impl PostgresQueryExecutor { }) .collect(); - Ok(Arc::new(Schema { fields })) + Ok(Arc::new(fields)) } } @@ -113,7 +113,7 @@ impl QueryExecutor for PostgresQueryExecutor { } } - async fn describe(&self, stmt: &Statement) -> PgWireResult> { + async fn describe(&self, stmt: &Statement) -> PgWireResult> { match stmt { Statement::Query(_query) => { let schema = self diff --git a/nexus/peer-postgres/src/stream.rs b/nexus/peer-postgres/src/stream.rs index 21905c1cc6..230d2dca7d 100644 --- a/nexus/peer-postgres/src/stream.rs +++ b/nexus/peer-postgres/src/stream.rs @@ -1,7 +1,7 @@ use bytes::Bytes; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use futures::Stream; -use peer_cursor::{Record, RecordStream, SchemaRef}; +use peer_cursor::{Record, RecordStream, Schema}; use pgwire::error::{PgWireError, PgWireResult}; use postgres_inet::MaskedIpAddr; use rust_decimal::Decimal; @@ -14,11 +14,11 @@ use uuid::Uuid; use value::{array::ArrayValue, Value}; pub struct PgRecordStream { row_stream: Pin>, - schema: SchemaRef, + schema: Schema, } impl PgRecordStream { - pub fn new(row_stream: RowStream, schema: SchemaRef) -> Self { + pub fn new(row_stream: RowStream, schema: Schema) -> Self { Self { row_stream: Box::pin(row_stream), schema, @@ -277,7 +277,7 @@ impl Stream for PgRecordStream { } impl RecordStream for PgRecordStream { - fn schema(&self) -> SchemaRef { + fn schema(&self) -> Schema { self.schema.clone() } } diff --git a/nexus/peer-snowflake/src/cursor.rs b/nexus/peer-snowflake/src/cursor.rs index 475a2d7f35..318a6d04d8 100644 --- a/nexus/peer-snowflake/src/cursor.rs +++ b/nexus/peer-snowflake/src/cursor.rs @@ -1,15 +1,14 @@ use crate::SnowflakeQueryExecutor; use dashmap::DashMap; use futures::StreamExt; 
-use peer_cursor::{QueryExecutor, QueryOutput, Records, SchemaRef, SendableStream}; +use peer_cursor::{QueryExecutor, QueryOutput, Records, Schema, SendableStream}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use sqlparser::ast::Statement; -use tokio::sync::Mutex; pub struct SnowflakeCursor { position: usize, - stream: Mutex, - schema: SchemaRef, + stream: SendableStream, + schema: Schema, } pub struct SnowflakeCursorManager { @@ -39,7 +38,7 @@ impl SnowflakeCursorManager { // Create a new cursor let cursor = SnowflakeCursor { position: 0, - stream: Mutex::new(stream), + stream, schema, }; @@ -72,9 +71,8 @@ impl SnowflakeCursorManager { let prev_end = cursor.position; let mut cursor_position = cursor.position; { - let mut stream = cursor.stream.lock().await; while cursor_position - prev_end < count { - match stream.next().await { + match cursor.stream.next().await { Some(Ok(record)) => { records.push(record); cursor_position += 1; diff --git a/nexus/peer-snowflake/src/lib.rs b/nexus/peer-snowflake/src/lib.rs index ac4d0154d9..c58ad68902 100644 --- a/nexus/peer-snowflake/src/lib.rs +++ b/nexus/peer-snowflake/src/lib.rs @@ -1,7 +1,7 @@ use anyhow::Context; use async_recursion::async_recursion; use cursor::SnowflakeCursorManager; -use peer_cursor::{CursorModification, QueryExecutor, QueryOutput, SchemaRef}; +use peer_cursor::{CursorModification, QueryExecutor, QueryOutput, Schema}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use sqlparser::dialect::GenericDialect; use sqlparser::parser; @@ -395,7 +395,7 @@ impl QueryExecutor for SnowflakeQueryExecutor { } } - async fn describe(&self, stmt: &Statement) -> PgWireResult> { + async fn describe(&self, stmt: &Statement) -> PgWireResult> { match stmt { Statement::Query(query) => { let mut new_query = query.clone(); diff --git a/nexus/peer-snowflake/src/stream.rs b/nexus/peer-snowflake/src/stream.rs index 4740270d12..efac7b7e1f 100644 --- a/nexus/peer-snowflake/src/stream.rs +++ b/nexus/peer-snowflake/src/stream.rs @@ -1,8 +1,7 @@ use crate::{auth::SnowflakeAuth, PartitionResult, ResultSet}; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; use futures::Stream; -use peer_cursor::Schema; -use peer_cursor::{Record, RecordStream, SchemaRef}; +use peer_cursor::{Record, RecordStream, Schema}; use pgwire::{ api::{ results::{FieldFormat, FieldInfo}, @@ -14,6 +13,7 @@ use secrecy::ExposeSecret; use serde::Deserialize; use std::{ pin::Pin, + sync::Arc, task::{Context, Poll}, }; use value::Value::{ @@ -40,7 +40,7 @@ pub(crate) enum SnowflakeDataType { } pub struct SnowflakeSchema { - schema: SchemaRef, + schema: Schema, } fn convert_field_type(field_type: &SnowflakeDataType) -> Type { @@ -63,20 +63,20 @@ impl SnowflakeSchema { pub fn from_result_set(result_set: &ResultSet) -> Self { let fields = result_set.resultSetMetaData.rowType.clone(); - let schema = SchemaRef::new(Schema { - fields: fields + let schema = Arc::new( + fields .iter() .map(|field| { let datatype = convert_field_type(&field.r#type); FieldInfo::new(field.name.clone(), None, None, datatype, FieldFormat::Text) }) .collect(), - }); + ); Self { schema } } - pub fn schema(&self) -> SchemaRef { + pub fn schema(&self) -> Schema { self.schema.clone() } } @@ -249,7 +249,7 @@ impl Stream for SnowflakeRecordStream { } impl RecordStream for SnowflakeRecordStream { - fn schema(&self) -> SchemaRef { + fn schema(&self) -> Schema { self.schema.schema() } } diff --git a/nexus/server/src/main.rs b/nexus/server/src/main.rs index bb2219512e..103e2b0537 
100644 --- a/nexus/server/src/main.rs +++ b/nexus/server/src/main.rs @@ -13,11 +13,12 @@ use clap::Parser; use cursor::PeerCursors; use dashmap::{mapref::entry::Entry as DashEntry, DashMap}; use flow_rs::grpc::{FlowGrpcClient, PeerValidationResult}; +use futures::join; use peer_bigquery::BigQueryQueryExecutor; use peer_connections::{PeerConnectionTracker, PeerConnections}; use peer_cursor::{ util::{records_to_query_response, sendable_stream_to_query_response}, - QueryExecutor, QueryOutput, SchemaRef, + QueryExecutor, QueryOutput, Schema, }; use peerdb_parser::{NexusParsedStatement, NexusQueryParser, NexusStatement}; use pgwire::{ @@ -40,7 +41,7 @@ use pt::{ }; use rand::Rng; use tokio::signal::unix::{signal, SignalKind}; -use tokio::sync::{Mutex, MutexGuard}; +use tokio::sync::Mutex; use tokio::{io::AsyncWriteExt, net::TcpListener}; use tracing_appender::non_blocking::WorkerGuard; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; @@ -78,7 +79,7 @@ impl AuthSource for FixedPasswordAuthSource { } pub struct NexusBackend { - catalog: Arc>, + catalog: Arc, peer_connections: PeerConnectionTracker, query_parser: NexusQueryParser, peer_cursors: Mutex, @@ -89,7 +90,7 @@ pub struct NexusBackend { impl NexusBackend { pub fn new( - catalog: Arc>, + catalog: Arc, peer_connections: PeerConnectionTracker, flow_handler: Option>>, peerdb_fdw_mode: bool, @@ -161,7 +162,7 @@ impl NexusBackend { } async fn check_for_mirror( - catalog: &MutexGuard<'_, Catalog>, + catalog: &Catalog, flow_name: &str, ) -> PgWireResult> { let workflow_details = catalog @@ -175,10 +176,7 @@ impl NexusBackend { Ok(workflow_details) } - async fn get_peer_of_mirror( - catalog: &MutexGuard<'_, Catalog>, - peer_name: &str, - ) -> PgWireResult { + async fn get_peer_of_mirror(catalog: &Catalog, peer_name: &str) -> PgWireResult { let peer = catalog.get_peer(peer_name).await.map_err(|err| { PgWireError::ApiError(format!("unable to get peer {:?}: {:?}", peer_name, err).into()) })?; @@ -251,13 +249,13 @@ impl NexusBackend { )); } - let catalog = self.catalog.lock().await; tracing::info!( "DROP MIRROR: mirror_name: {}, if_exists: {}", flow_job_name, if_exists ); - let workflow_details = catalog + let workflow_details = self + .catalog .get_workflow_details_for_flow_job(flow_job_name) .await .map_err(|err| { @@ -284,7 +282,7 @@ impl NexusBackend { format!("unable to shutdown flow job: {:?}", err).into(), ) })?; - catalog + self.catalog .delete_flow_job_entry(flow_job_name) .await .map_err(|err| { @@ -334,14 +332,13 @@ impl NexusBackend { } let mirror_details; { - let catalog = self.catalog.lock().await; mirror_details = - Self::check_for_mirror(&catalog, &qrep_flow_job.name).await?; + Self::check_for_mirror(self.catalog.as_ref(), &qrep_flow_job.name) + .await?; } if mirror_details.is_none() { { - let catalog = self.catalog.lock().await; - catalog + self.catalog .create_qrep_flow_job_entry(qrep_flow_job) .await .map_err(|err| { @@ -399,8 +396,7 @@ impl NexusBackend { })?; } - let catalog = self.catalog.lock().await; - catalog.create_peer(peer.as_ref()).await.map_err(|e| { + self.catalog.create_peer(peer.as_ref()).await.map_err(|e| { PgWireError::UserError(Box::new(ErrorInfo::new( "ERROR".to_owned(), "internal_error".to_owned(), @@ -420,8 +416,8 @@ impl NexusBackend { "flow service is not configured".into(), )); } - let catalog = self.catalog.lock().await; - let mirror_details = Self::check_for_mirror(&catalog, &flow_job.name).await?; + let mirror_details = + Self::check_for_mirror(self.catalog.as_ref(), &flow_job.name).await?; if 
mirror_details.is_none() { // reject duplicate source tables or duplicate target tables let table_mappings_count = flow_job.table_mappings.len(); @@ -450,7 +446,7 @@ impl NexusBackend { } } - catalog + self.catalog .create_cdc_flow_job_entry(flow_job) .await .map_err(|err| { @@ -460,10 +456,12 @@ impl NexusBackend { })?; // get source and destination peers - let src_peer = - Self::get_peer_of_mirror(&catalog, &flow_job.source_peer).await?; - let dst_peer = - Self::get_peer_of_mirror(&catalog, &flow_job.target_peer).await?; + let (src_peer, dst_peer) = join!( + Self::get_peer_of_mirror(self.catalog.as_ref(), &flow_job.source_peer), + Self::get_peer_of_mirror(self.catalog.as_ref(), &flow_job.target_peer), + ); + let src_peer = src_peer?; + let dst_peer = dst_peer?; // make a request to the flow service to start the job. let mut flow_handler = self.flow_handler.as_ref().unwrap().lock().await; @@ -476,7 +474,7 @@ impl NexusBackend { ) })?; - catalog + self.catalog .update_workflow_id_for_flow_job(&flow_job.name, &workflow_id) .await .map_err(|err| { @@ -505,8 +503,7 @@ impl NexusBackend { } if let Some(job) = { - let catalog = self.catalog.lock().await; - catalog + self.catalog .get_qrep_flow_job_by_name(flow_job_name) .await .map_err(|err| { @@ -540,17 +537,21 @@ impl NexusBackend { )); } - let catalog = self.catalog.lock().await; tracing::info!( "DROP PEER: peer_name: {}, if_exists: {}", peer_name, if_exists ); - let peer_exists = catalog.check_peer_entry(peer_name).await.map_err(|err| { - PgWireError::ApiError( - format!("unable to query catalog for peer metadata: {:?}", err).into(), - ) - })?; + let peer_exists = + self.catalog + .check_peer_entry(peer_name) + .await + .map_err(|err| { + PgWireError::ApiError( + format!("unable to query catalog for peer metadata: {:?}", err) + .into(), + ) + })?; tracing::info!("peer exist count: {}", peer_exists); if peer_exists != 0 { let mut flow_handler = self.flow_handler.as_ref().unwrap().lock().await; @@ -590,8 +591,7 @@ impl NexusBackend { let qrep_config = { // retrieve the mirror job since DROP MIRROR will delete the row later. 
- let catalog = self.catalog.lock().await; - catalog + self.catalog .get_qrep_config_proto(mirror_name) .await .map_err(|err| { @@ -632,8 +632,7 @@ impl NexusBackend { ) })?; // relock catalog, DROP MIRROR is done with it now - let catalog = self.catalog.lock().await; - catalog + self.catalog .update_workflow_id_for_flow_job( &qrep_config.flow_job_name, &workflow_id, @@ -674,13 +673,13 @@ impl NexusBackend { )); } - let catalog = self.catalog.lock().await; tracing::info!( "[PAUSE MIRROR] mirror_name: {}, if_exists: {}", flow_job_name, if_exists ); - let workflow_details = catalog + let workflow_details = self + .catalog .get_workflow_details_for_flow_job(flow_job_name) .await .map_err(|err| { @@ -737,13 +736,13 @@ impl NexusBackend { )); } - let catalog = self.catalog.lock().await; tracing::info!( "[RESUME MIRROR] mirror_name: {}, if_exists: {}", flow_job_name, if_exists ); - let workflow_details = catalog + let workflow_details = self + .catalog .get_workflow_details_for_flow_job(flow_job_name) .await .map_err(|err| { @@ -805,8 +804,7 @@ impl NexusBackend { } QueryAssociation::Catalog => { tracing::info!("handling catalog query: {}", stmt); - let catalog = self.catalog.lock().await; - Arc::clone(catalog.get_executor()) + Arc::clone(self.catalog.get_executor()) } }; @@ -829,10 +827,7 @@ impl NexusBackend { analyzer::CursorEvent::Close(c) => peer_cursors.get_peer(&c), }; match peer { - None => { - let catalog = self.catalog.lock().await; - Arc::clone(catalog.get_executor()) - } + None => Arc::clone(self.catalog.get_executor()), Some(peer) => self.get_peer_executor(peer).await.map_err(|err| { PgWireError::ApiError( format!("unable to get peer executor: {:?}", err).into(), @@ -850,22 +845,18 @@ impl NexusBackend { } async fn run_qrep_mirror(&self, qrep_flow_job: &QRepFlowJob) -> PgWireResult { - let catalog = self.catalog.lock().await; - + let (src_peer, dst_peer) = join!( + self.catalog.get_peer(&qrep_flow_job.source_peer), + self.catalog.get_peer(&qrep_flow_job.target_peer), + ); // get source and destination peers - let src_peer = catalog - .get_peer(&qrep_flow_job.source_peer) - .await - .map_err(|err| { - PgWireError::ApiError(format!("unable to get source peer: {:?}", err).into()) - })?; + let src_peer = src_peer.map_err(|err| { + PgWireError::ApiError(format!("unable to get source peer: {:?}", err).into()) + })?; - let dst_peer = catalog - .get_peer(&qrep_flow_job.target_peer) - .await - .map_err(|err| { - PgWireError::ApiError(format!("unable to get destination peer: {:?}", err).into()) - })?; + let dst_peer = dst_peer.map_err(|err| { + PgWireError::ApiError(format!("unable to get destination peer: {:?}", err).into()) + })?; // make a request to the flow service to start the job. let mut flow_handler = self.flow_handler.as_ref().unwrap().lock().await; @@ -876,7 +867,7 @@ impl NexusBackend { PgWireError::ApiError(format!("unable to submit job: {:?}", err).into()) })?; - catalog + self.catalog .update_workflow_id_for_flow_job(&qrep_flow_job.name, &workflow_id) .await .map_err(|err| { @@ -1045,7 +1036,7 @@ impl ExtendedQueryHandler for NexusBackend { NexusStatement::PeerCursor { .. } => Ok(DescribeResponse::no_data()), NexusStatement::Empty => Ok(DescribeResponse::no_data()), NexusStatement::PeerQuery { stmt, assoc } => { - let schema: Option = match assoc { + let schema: Option = match assoc { QueryAssociation::Peer(peer) => { // if the peer is of type bigquery, let us route the query to bq. 
match &peer.config { @@ -1087,11 +1078,7 @@ impl ExtendedQueryHandler for NexusBackend { } } } - QueryAssociation::Catalog => { - let catalog = self.catalog.lock().await; - let executor = catalog.get_executor(); - executor.describe(stmt).await? - } + QueryAssociation::Catalog => self.catalog.get_executor().describe(stmt).await?, }; if let Some(described_schema) = schema { if self.peerdb_fdw_mode { @@ -1099,7 +1086,7 @@ impl ExtendedQueryHandler for NexusBackend { } else { Ok(DescribeResponse::new( param_types, - described_schema.fields.clone(), + (*described_schema).clone(), )) } } else { @@ -1320,7 +1307,7 @@ pub async fn main() -> anyhow::Result<()> { let tracker = PeerConnectionTracker::new(conn_uuid, conn_peer_conns); let processor = Arc::new(NexusBackend::new( - Arc::new(Mutex::new(catalog)), + Arc::new(catalog), tracker, conn_flow_handler, peerdb_fdw_mode, diff --git a/protos/route.proto b/protos/route.proto index db330104b3..577be49a4f 100644 --- a/protos/route.proto +++ b/protos/route.proto @@ -114,7 +114,12 @@ message SchemaTablesRequest { } message SchemaTablesResponse { - repeated string tables = 1; + repeated TableResponse tables = 1; +} + +message TableResponse { + string table_name = 1; + bool can_mirror = 2; } message AllTablesResponse { diff --git a/ui/app/alert-config/layout.tsx b/ui/app/alert-config/layout.tsx index 2cabddbcbb..69a53b44ea 100644 --- a/ui/app/alert-config/layout.tsx +++ b/ui/app/alert-config/layout.tsx @@ -1,12 +1,7 @@ import SidebarComponent from '@/components/SidebarComponent'; import { Layout } from '@/lib/Layout'; -import { cookies } from 'next/headers'; import { PropsWithChildren } from 'react'; export default function PageLayout({ children }: PropsWithChildren) { - return ( - }> - {children} - - ); + return }>{children}; } diff --git a/ui/app/api/alert-config/route.ts b/ui/app/api/alert-config/route.ts index 64d5a24f06..6966d2c146 100644 --- a/ui/app/api/alert-config/route.ts +++ b/ui/app/api/alert-config/route.ts @@ -1,10 +1,9 @@ import { alertConfigType } from '@/app/alert-config/validation'; import prisma from '@/app/utils/prisma'; - -export const dynamic = 'force-dynamic'; +import { alerting_config } from '@prisma/client'; export async function GET() { - const configs = await prisma.alerting_config.findMany(); + const configs: alerting_config[] = await prisma.alerting_config.findMany(); const serializedConfigs = configs.map((config) => ({ ...config, id: String(config.id), diff --git a/ui/app/api/auth/[...nextauth]/route.ts b/ui/app/api/auth/[...nextauth]/route.ts new file mode 100644 index 0000000000..a19cf278ac --- /dev/null +++ b/ui/app/api/auth/[...nextauth]/route.ts @@ -0,0 +1,5 @@ +import { authOptions } from '@/app/auth/options'; +import NextAuth from 'next-auth'; + +const handler = NextAuth(authOptions); +export { handler as GET, handler as POST }; diff --git a/ui/app/api/login/route.ts b/ui/app/api/login/route.ts deleted file mode 100644 index 6a6bf202e6..0000000000 --- a/ui/app/api/login/route.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { cookies } from 'next/headers'; - -export async function POST(request: Request) { - const { password } = await request.json(); - if (process.env.PEERDB_PASSWORD !== password) { - return new Response( - JSON.stringify({ - error: - 'Your password is incorrect. 
Please check your password and try again.', - }) - ); - } - cookies().set('auth', password, { - expires: Date.now() + 24 * 60 * 60 * 1000, - secure: process.env.PEERDB_SECURE_COOKIES === 'true', - }); - return new Response('{}'); -} diff --git a/ui/app/api/logout/route.ts b/ui/app/api/logout/route.ts deleted file mode 100644 index 5faf5a8c84..0000000000 --- a/ui/app/api/logout/route.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { cookies } from 'next/headers'; - -export async function POST(req: Request) { - cookies().delete('auth'); - return new Response(''); -} diff --git a/ui/app/api/mirrors/alerts/route.ts b/ui/app/api/mirrors/alerts/route.ts index 13cc612503..3596c4a294 100644 --- a/ui/app/api/mirrors/alerts/route.ts +++ b/ui/app/api/mirrors/alerts/route.ts @@ -1,7 +1,5 @@ import prisma from '@/app/utils/prisma'; -export const dynamic = 'force-dynamic'; - export async function POST(request: Request) { const { flowName } = await request.json(); const errCount = await prisma.flow_errors.count({ diff --git a/ui/app/api/peers/tables/all/route.ts b/ui/app/api/peers/tables/all/route.ts index 92223f7bbf..0281cc7067 100644 --- a/ui/app/api/peers/tables/all/route.ts +++ b/ui/app/api/peers/tables/all/route.ts @@ -1,5 +1,5 @@ -import { UTablesResponse } from '@/app/dto/PeersDTO'; -import { SchemaTablesResponse } from '@/grpc_generated/route'; +import { UTablesAllResponse } from '@/app/dto/PeersDTO'; +import { AllTablesResponse } from '@/grpc_generated/route'; import { GetFlowHttpAddressFromEnv } from '@/rpc/http'; export async function POST(request: Request) { @@ -7,12 +7,12 @@ export async function POST(request: Request) { const { peerName } = body; const flowServiceAddr = GetFlowHttpAddressFromEnv(); try { - const tableList: SchemaTablesResponse = await fetch( + const tableList: AllTablesResponse = await fetch( `${flowServiceAddr}/v1/peers/tables/all?peer_name=${peerName}` ).then((res) => { return res.json(); }); - let response: UTablesResponse = { + let response: UTablesAllResponse = { tables: tableList.tables, }; return new Response(JSON.stringify(response)); diff --git a/ui/app/auth/options.ts b/ui/app/auth/options.ts new file mode 100644 index 0000000000..8d68e1b0ea --- /dev/null +++ b/ui/app/auth/options.ts @@ -0,0 +1,38 @@ +import { Configuration } from '@/app/config/config'; +import { AuthOptions } from 'next-auth'; +import CredentialsProvider from 'next-auth/providers/credentials'; +import { Provider } from 'next-auth/providers/index'; + +function getEnabledProviders(): Provider[] { + return [ + CredentialsProvider({ + name: 'Password', + credentials: { + password: { label: 'Password', type: 'password' }, + }, + async authorize(credentials, req) { + if ( + credentials == null || + credentials.password != Configuration.authentication.PEERDB_PASSWORD + ) { + return null; + } + return { id: '1', name: 'Admin' }; + }, + }), + ]; +} +export const authOptions: AuthOptions = { + providers: getEnabledProviders(), + debug: false, + session: { + strategy: 'jwt', + maxAge: 60 * 60, // 1h + }, + // adapter: PrismaAdapter(prisma), + secret: Configuration.authentication.NEXTAUTH_SECRET, + theme: { + colorScheme: 'light', + logo: '/images/peerdb-combinedMark.svg', + }, +}; diff --git a/ui/app/config/config.ts b/ui/app/config/config.ts new file mode 100644 index 0000000000..4e613fa084 --- /dev/null +++ b/ui/app/config/config.ts @@ -0,0 +1,7 @@ +export var Configuration = { + authentication: { + PEERDB_PASSWORD: process.env.PEERDB_PASSWORD, + // Set this in production to a static value + NEXTAUTH_SECRET: 
process.env.NEXTAUTH_SECRET, + }, +}; diff --git a/ui/app/dto/MirrorsDTO.ts b/ui/app/dto/MirrorsDTO.ts index 4a76200fd4..e33890fe03 100644 --- a/ui/app/dto/MirrorsDTO.ts +++ b/ui/app/dto/MirrorsDTO.ts @@ -20,6 +20,7 @@ export type TableMapRow = { partitionKey: string; exclude: string[]; selected: boolean; + canMirror: boolean; }; export type SyncStatusRow = { diff --git a/ui/app/dto/PeersDTO.ts b/ui/app/dto/PeersDTO.ts index 232e1b2c3c..80de38124b 100644 --- a/ui/app/dto/PeersDTO.ts +++ b/ui/app/dto/PeersDTO.ts @@ -4,6 +4,7 @@ import { S3Config, SnowflakeConfig, } from '@/grpc_generated/peers'; +import { TableResponse } from '@/grpc_generated/route'; export type UValidatePeerResponse = { valid: boolean; @@ -20,6 +21,10 @@ export type USchemasResponse = { }; export type UTablesResponse = { + tables: TableResponse[]; +}; + +export type UTablesAllResponse = { tables: string[]; }; diff --git a/ui/app/login/page.tsx b/ui/app/login/page.tsx deleted file mode 100644 index 1172b0e811..0000000000 --- a/ui/app/login/page.tsx +++ /dev/null @@ -1,106 +0,0 @@ -'use client'; -import Image from 'next/image'; -import { useRouter, useSearchParams } from 'next/navigation'; -import { useState } from 'react'; - -import { Button } from '@/lib/Button'; -import { Icon } from '@/lib/Icon'; -import { Layout } from '@/lib/Layout'; -import { TextField } from '@/lib/TextField'; - -export default function Login() { - const router = useRouter(); - const searchParams = useSearchParams(); - const [pass, setPass] = useState(''); - const [show, setShow] = useState(false); - const [error, setError] = useState(() => ''); - const login = () => { - fetch('/api/login', { - method: 'POST', - body: JSON.stringify({ password: pass }), - }) - .then((res) => res.json()) - .then((res) => { - if (res.error) { - setError(res.error); - } else router.push('/'); - }); - }; - return ( - -
-    [login form markup lost in extraction: the deleted page rendered the PeerDB logo, a password TextField that updated state via setPass and submitted on Enter, a Login button wired to login(), and an error Label shown when the password check failed]
- ); -} diff --git a/ui/app/mirrors/create/cdc/schemabox.tsx b/ui/app/mirrors/create/cdc/schemabox.tsx index d030fbc74a..34c49c0b16 100644 --- a/ui/app/mirrors/create/cdc/schemabox.tsx +++ b/ui/app/mirrors/create/cdc/schemabox.tsx @@ -7,6 +7,7 @@ import { Label } from '@/lib/Label'; import { RowWithCheckbox } from '@/lib/Layout'; import { SearchField } from '@/lib/SearchField'; import { TextField } from '@/lib/TextField'; +import { Tooltip } from '@/lib/Tooltip'; import { Dispatch, SetStateAction, @@ -17,7 +18,12 @@ import { import { BarLoader } from 'react-spinners/'; import { fetchColumns, fetchTables } from '../handlers'; import ColumnBox from './columnbox'; -import { expandableStyle, schemaBoxStyle, tableBoxStyle } from './styles'; +import { + expandableStyle, + schemaBoxStyle, + tableBoxStyle, + tooltipStyle, +} from './styles'; interface SchemaBoxProps { sourcePeer: string; @@ -210,12 +216,29 @@ const SchemaBox = ({ > - {row.source} - + + + } action={ handleTableSelect(state, row.source) @@ -223,7 +246,6 @@ const SchemaBox = ({ /> } /> -
res.json()); - let tables = []; - const tableNames = tablesRes.tables; - if (tableNames) { - for (const tableName of tableNames) { + let tables: TableMapRow[] = []; + const tableRes = tablesRes.tables; + if (tableRes) { + for (const tableObject of tableRes) { // setting defaults: // for bigquery, tables are not schema-qualified const dstName = peerType != undefined && dBTypeToJSON(peerType) == 'BIGQUERY' - ? tableName - : `${schemaName}.${tableName}`; + ? tableObject.tableName + : `${schemaName}.${tableObject.tableName}`; tables.push({ schema: schemaName, - source: `${schemaName}.${tableName}`, + source: `${schemaName}.${tableObject.tableName}`, destination: dstName, partitionKey: '', exclude: [], selected: false, + canMirror: tableObject.canMirror, }); } } @@ -325,7 +327,7 @@ export const fetchColumns = async ( export const fetchAllTables = async (peerName: string) => { if (peerName?.length === 0) return []; - const tablesRes: UTablesResponse = await fetch('/api/peers/tables/all', { + const tablesRes: UTablesAllResponse = await fetch('/api/peers/tables/all', { method: 'POST', body: JSON.stringify({ peerName, diff --git a/ui/app/mirrors/edit/[mirrorId]/page.tsx b/ui/app/mirrors/edit/[mirrorId]/page.tsx index 7e5de948f9..4f68dcf20c 100644 --- a/ui/app/mirrors/edit/[mirrorId]/page.tsx +++ b/ui/app/mirrors/edit/[mirrorId]/page.tsx @@ -8,8 +8,6 @@ import { CDCMirror } from './cdc'; import NoMirror from './nomirror'; import SyncStatus from './syncStatus'; -export const dynamic = 'force-dynamic'; - type EditMirrorProps = { params: { mirrorId: string }; }; @@ -21,7 +19,7 @@ function getMirrorStatusUrl(mirrorId: string) { async function getMirrorStatus(mirrorId: string) { const url = getMirrorStatusUrl(mirrorId); - const resp = await fetch(url); + const resp = await fetch(url, { cache: 'no-store' }); const json = await resp.json(); return json; } diff --git a/ui/app/mirrors/layout.tsx b/ui/app/mirrors/layout.tsx index 2cabddbcbb..69a53b44ea 100644 --- a/ui/app/mirrors/layout.tsx +++ b/ui/app/mirrors/layout.tsx @@ -1,12 +1,7 @@ import SidebarComponent from '@/components/SidebarComponent'; import { Layout } from '@/lib/Layout'; -import { cookies } from 'next/headers'; import { PropsWithChildren } from 'react'; export default function PageLayout({ children }: PropsWithChildren) { - return ( - }> - {children} - - ); + return }>{children}; } diff --git a/ui/app/mirrors/page.tsx b/ui/app/mirrors/page.tsx index 59bee320d3..0fe7012806 100644 --- a/ui/app/mirrors/page.tsx +++ b/ui/app/mirrors/page.tsx @@ -13,8 +13,6 @@ import useSWR from 'swr'; import { fetcher } from '../utils/swr'; import { CDCFlows, QRepFlows } from './tables'; -export const dynamic = 'force-dynamic'; - export default function Mirrors() { const { data: flows, diff --git a/ui/app/mirrors/status/qrep/[mirrorId]/page.tsx b/ui/app/mirrors/status/qrep/[mirrorId]/page.tsx index c0a40085b3..eeaec8cb77 100644 --- a/ui/app/mirrors/status/qrep/[mirrorId]/page.tsx +++ b/ui/app/mirrors/status/qrep/[mirrorId]/page.tsx @@ -5,8 +5,6 @@ import QRepConfigViewer from './qrepConfigViewer'; import QrepGraph from './qrepGraph'; import QRepStatusTable, { QRepPartitionStatus } from './qrepStatusTable'; -export const dynamic = 'force-dynamic'; - type QRepMirrorStatusProps = { params: { mirrorId: string }; }; diff --git a/ui/app/mirrors/status/qrep/[mirrorId]/qrepConfigViewer.tsx b/ui/app/mirrors/status/qrep/[mirrorId]/qrepConfigViewer.tsx index ac070d744d..b822fe29e7 100644 --- a/ui/app/mirrors/status/qrep/[mirrorId]/qrepConfigViewer.tsx +++ 
b/ui/app/mirrors/status/qrep/[mirrorId]/qrepConfigViewer.tsx @@ -5,8 +5,6 @@ import { Icon } from '@/lib/Icon'; import { Label } from '@/lib/Label'; import { ProgressCircle } from '@/lib/ProgressCircle'; -export const dynamic = 'force-dynamic'; - type QRepConfigViewerProps = { mirrorId: string; }; diff --git a/ui/app/page.tsx b/ui/app/page.tsx index d89f981887..4e84221bb7 100644 --- a/ui/app/page.tsx +++ b/ui/app/page.tsx @@ -1,11 +1,10 @@ import SidebarComponent from '@/components/SidebarComponent'; import { Header } from '@/lib/Header'; import { Layout, LayoutMain } from '@/lib/Layout'; -import { cookies } from 'next/headers'; export default function Home() { return ( - }> + }>
PeerDB Home Page
diff --git a/ui/app/peers/[peerName]/page.tsx b/ui/app/peers/[peerName]/page.tsx index 78bd430bf3..45ddb41e61 100644 --- a/ui/app/peers/[peerName]/page.tsx +++ b/ui/app/peers/[peerName]/page.tsx @@ -5,7 +5,6 @@ import { GetFlowHttpAddressFromEnv } from '@/rpc/http'; import Link from 'next/link'; import SlotTable from './slottable'; import StatTable from './stattable'; -export const dynamic = 'force-dynamic'; type DataConfigProps = { params: { peerName: string }; diff --git a/ui/app/peers/layout.tsx b/ui/app/peers/layout.tsx index 2cabddbcbb..69a53b44ea 100644 --- a/ui/app/peers/layout.tsx +++ b/ui/app/peers/layout.tsx @@ -1,12 +1,7 @@ import SidebarComponent from '@/components/SidebarComponent'; import { Layout } from '@/lib/Layout'; -import { cookies } from 'next/headers'; import { PropsWithChildren } from 'react'; export default function PageLayout({ children }: PropsWithChildren) { - return ( - }> - {children} - - ); + return }>{children}; } diff --git a/ui/components/Logout.tsx b/ui/components/Logout.tsx index dc262245f5..8f33e6f8d9 100644 --- a/ui/components/Logout.tsx +++ b/ui/components/Logout.tsx @@ -1,17 +1,22 @@ 'use client'; import { Button } from '@/lib/Button'; +import { useSession } from 'next-auth/react'; +import { useRouter } from 'next/navigation'; export default function Logout() { - return ( - - ); + const { data: session } = useSession(); + const router = useRouter(); + if (session) { + return ( + + ); + } } diff --git a/ui/components/SidebarComponent.tsx b/ui/components/SidebarComponent.tsx index 4f0304770e..0ecf06598b 100644 --- a/ui/components/SidebarComponent.tsx +++ b/ui/components/SidebarComponent.tsx @@ -8,9 +8,11 @@ import { Icon } from '@/lib/Icon'; import { Label } from '@/lib/Label'; import { RowWithSelect } from '@/lib/Layout'; import { Sidebar, SidebarItem } from '@/lib/Sidebar'; +import { SessionProvider } from 'next-auth/react'; import Link from 'next/link'; import useSWR from 'swr'; import { useLocalStorage } from 'usehooks-ts'; + const centerFlexStyle = { display: 'flex', justifyContent: 'center', @@ -19,7 +21,7 @@ const centerFlexStyle = { marginBottom: '0.5rem', }; -export default function SidebarComponent(props: { logout?: boolean }) { +export default function SidebarComponent(props: {}) { const timezones = ['UTC', 'Local', 'Relative']; const [zone, setZone] = useLocalStorage('timezone-ui', ''); @@ -33,78 +35,80 @@ export default function SidebarComponent(props: { logout?: boolean }) { ); return ( - -
- -
- - } - bottomRow={ - <> -
-
- Timezone:} - action={ - - } - /> + + +
+
-
- {props.logout && } - - } - bottomLabel={ -
- -
- } - > - } - > - Peers - - } - > - Mirrors - - } + } + bottomRow={ + <> +
+
+ Timezone:} + action={ + + } + /> +
+
+ + + } + bottomLabel={ +
+ +
+ } > - Alert Configuration -
- + } + > + Peers + + } + > + Mirrors + + } + > + Alert Configuration + + + ); } diff --git a/ui/middleware.ts b/ui/middleware.ts index 3f616b1a7a..81dcfdbcf4 100644 --- a/ui/middleware.ts +++ b/ui/middleware.ts @@ -1,18 +1,13 @@ -import type { NextRequest } from 'next/server'; -import { NextResponse } from 'next/server'; +import { Configuration } from '@/app/config/config'; +import { withAuth } from 'next-auth/middleware'; +import { NextRequest, NextResponse } from 'next/server'; -export default function middleware(req: NextRequest) { - if ( - req.nextUrl.pathname !== '/login' && - req.nextUrl.pathname !== '/api/login' && - req.nextUrl.pathname !== '/api/logout' && - process.env.PEERDB_PASSWORD && - req.cookies.get('auth')?.value !== process.env.PEERDB_PASSWORD - ) { - req.cookies.delete('auth'); - return NextResponse.redirect(new URL('/login?reject', req.url)); +const authMiddleware = withAuth({}); + +export default async function middleware(req: NextRequest, resp: NextResponse) { + if (Configuration.authentication.PEERDB_PASSWORD) { + return (authMiddleware as any)(req); } - return NextResponse.next(); } export const config = { diff --git a/ui/package-lock.json b/ui/package-lock.json index 99502c605c..4b99f89c1e 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -36,6 +36,7 @@ "moment": "^2.30.1", "moment-timezone": "^0.5.44", "next": "^14.0.4", + "next-auth": "^4.24.5", "prop-types": "^15.8.1", "protobufjs": "^7.2.5", "react": "18.2.0", @@ -3572,6 +3573,14 @@ "node": ">= 8" } }, + "node_modules/@panva/hkdf": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@panva/hkdf/-/hkdf-1.1.1.tgz", + "integrity": "sha512-dhPeilub1NuIG0X5Kvhh9lH4iW3ZsHlnzwgwbOlgwQ2wG1IqFzsgHqmKPk3WzsdWAeaxKJxgM0+W433RmN45GA==", + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/@pkgjs/parseargs": { "version": "0.11.0", "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz", @@ -9628,7 +9637,6 @@ "version": "0.5.0", "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.5.0.tgz", "integrity": "sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==", - "dev": true, "engines": { "node": ">= 0.6" } @@ -14012,6 +14020,14 @@ "jiti": "bin/jiti.js" } }, + "node_modules/jose": { + "version": "4.15.4", + "resolved": "https://registry.npmjs.org/jose/-/jose-4.15.4.tgz", + "integrity": "sha512-W+oqK4H+r5sITxfxpSU+MMdr/YSWGvgZMQDIsNoBDGGy4i7GBPTtvFKibQzW06n3U3TqHjhvBJsirShsEJ6eeQ==", + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -14987,6 +15003,41 @@ } } }, + "node_modules/next-auth": { + "version": "4.24.5", + "resolved": "https://registry.npmjs.org/next-auth/-/next-auth-4.24.5.tgz", + "integrity": "sha512-3RafV3XbfIKk6rF6GlLE4/KxjTcuMCifqrmD+98ejFq73SRoj2rmzoca8u764977lH/Q7jo6Xu6yM+Re1Mz/Og==", + "dependencies": { + "@babel/runtime": "^7.20.13", + "@panva/hkdf": "^1.0.2", + "cookie": "^0.5.0", + "jose": "^4.11.4", + "oauth": "^0.9.15", + "openid-client": "^5.4.0", + "preact": "^10.6.3", + "preact-render-to-string": "^5.1.19", + "uuid": "^8.3.2" + }, + "peerDependencies": { + "next": "^12.2.5 || ^13 || ^14", + "nodemailer": "^6.6.5", + "react": "^17.0.2 || ^18", + "react-dom": "^17.0.2 || ^18" + }, + "peerDependenciesMeta": { + "nodemailer": { + "optional": true + } + } + }, + "node_modules/next-auth/node_modules/uuid": { + "version": "8.3.2", + 
"resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz", + "integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==", + "bin": { + "uuid": "dist/bin/uuid" + } + }, "node_modules/next/node_modules/postcss": { "version": "8.4.31", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.31.tgz", @@ -15372,6 +15423,11 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/oauth": { + "version": "0.9.15", + "resolved": "https://registry.npmjs.org/oauth/-/oauth-0.9.15.tgz", + "integrity": "sha512-a5ERWK1kh38ExDEfoO6qUHJb32rd7aYmPHuyCu3Fta/cnICvYmgd2uhuKXvPD+PXB+gCEYYEaQdIRAjCOwAKNA==" + }, "node_modules/object-assign": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", @@ -15525,6 +15581,14 @@ "integrity": "sha512-zuHHiGTYTA1sYJ/wZN+t5HKZaH23i4yI1HMwbuXm24Nid7Dv0KcuRlKoNKS9UNfAVSBlnGLcuQrnOKWOZoEGaw==", "dev": true }, + "node_modules/oidc-token-hash": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.0.3.tgz", + "integrity": "sha512-IF4PcGgzAr6XXSff26Sk/+P4KZFJVuHAJZj3wgO3vX2bMdNVp/QXTP3P7CEm9V1IdG8lDLY3HhiqpsE/nOwpPw==", + "engines": { + "node": "^10.13.0 || >=12.0.0" + } + }, "node_modules/on-finished": { "version": "2.4.1", "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", @@ -15587,6 +15651,44 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/openid-client": { + "version": "5.6.4", + "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.6.4.tgz", + "integrity": "sha512-T1h3B10BRPKfcObdBklX639tVz+xh34O7GjofqrqiAQdm7eHsQ00ih18x6wuJ/E6FxdtS2u3FmUGPDeEcMwzNA==", + "dependencies": { + "jose": "^4.15.4", + "lru-cache": "^6.0.0", + "object-hash": "^2.2.0", + "oidc-token-hash": "^5.0.3" + }, + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, + "node_modules/openid-client/node_modules/lru-cache": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/openid-client/node_modules/object-hash": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz", + "integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==", + "engines": { + "node": ">= 6" + } + }, + "node_modules/openid-client/node_modules/yallist": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==" + }, "node_modules/optionator": { "version": "0.9.3", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.3.tgz", @@ -16278,6 +16380,31 @@ "resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz", "integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==" }, + "node_modules/preact": { + "version": "10.19.3", + "resolved": "https://registry.npmjs.org/preact/-/preact-10.19.3.tgz", + "integrity": "sha512-nHHTeFVBTHRGxJXKkKu5hT8C/YWBkPso4/Gad6xuj5dbptt9iF9NZr9pHbPhBrnT2klheu7mHTxTZ/LjwJiEiQ==", + "funding": { + "type": "opencollective", + "url": 
"https://opencollective.com/preact" + } + }, + "node_modules/preact-render-to-string": { + "version": "5.2.6", + "resolved": "https://registry.npmjs.org/preact-render-to-string/-/preact-render-to-string-5.2.6.tgz", + "integrity": "sha512-JyhErpYOvBV1hEPwIxc/fHWXPfnEGdRKxc8gFdAZ7XV4tlzyzG847XAyEZqoDnynP88akM4eaHcSOzNcLWFguw==", + "dependencies": { + "pretty-format": "^3.8.0" + }, + "peerDependencies": { + "preact": ">=10" + } + }, + "node_modules/preact-render-to-string/node_modules/pretty-format": { + "version": "3.8.0", + "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-3.8.0.tgz", + "integrity": "sha512-WuxUnVtlWL1OfZFQFuqvnvs6MiAGk9UNsBostyBOB0Is9wb5uRESevA6rnl/rkksXaGX3GzZhPup5d6Vp1nFew==" + }, "node_modules/prebuild-install": { "version": "7.1.1", "resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.1.tgz", diff --git a/ui/package.json b/ui/package.json index 3ba478cb94..4fc310e676 100644 --- a/ui/package.json +++ b/ui/package.json @@ -42,6 +42,7 @@ "moment": "^2.30.1", "moment-timezone": "^0.5.44", "next": "^14.0.4", + "next-auth": "^4.24.5", "prop-types": "^15.8.1", "protobufjs": "^7.2.5", "react": "18.2.0",