diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index da3b8724e0..3a076d1e66 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -287,7 +287,6 @@ func (a *FlowableActivity) SyncFlow( } defer connectors.CloseConnector(ctx, dstConn) - logger.Info("pulling records...") tblNameMapping := make(map[string]model.NameAndExclude, len(options.TableMappings)) for _, v := range options.TableMappings { tblNameMapping[v.SourceTableIdentifier] = model.NewNameAndExclude(v.DestinationTableIdentifier, v.Exclude) @@ -315,6 +314,7 @@ func (a *FlowableActivity) SyncFlow( if err != nil { return nil, err } + logger.Info("pulling records...", slog.Int64("LastOffset", lastOffset)) // start a goroutine to pull records from the source recordBatch := model.NewCDCRecordStream() @@ -346,7 +346,11 @@ func (a *FlowableActivity) SyncFlow( err = errGroup.Wait() if err != nil { a.Alerter.LogFlowError(ctx, flowName, err) - return nil, fmt.Errorf("failed in pull records when: %w", err) + if temporal.IsApplicationError(err) { + return nil, err + } else { + return nil, fmt.Errorf("failed in pull records when: %w", err) + } } logger.Info("no records to push") @@ -402,7 +406,11 @@ func (a *FlowableActivity) SyncFlow( err = errGroup.Wait() if err != nil { a.Alerter.LogFlowError(ctx, flowName, err) - return nil, fmt.Errorf("failed to pull records: %w", err) + if temporal.IsApplicationError(err) { + return nil, err + } else { + return nil, fmt.Errorf("failed to pull records: %w", err) + } } numRecords := res.NumRecordsSynced @@ -411,6 +419,7 @@ func (a *FlowableActivity) SyncFlow( logger.Info(fmt.Sprintf("pushed %d records in %d seconds", numRecords, int(syncDuration.Seconds()))) lastCheckpoint := recordBatch.GetLastCheckpoint() + srcConn.UpdateReplStateLastOffset(lastCheckpoint) err = monitoring.UpdateNumRowsAndEndLSNForCDCBatch( ctx, diff --git a/flow/alerting/alerting.go b/flow/alerting/alerting.go index 93cb946f78..acdbe6c290 100644 --- a/flow/alerting/alerting.go +++ b/flow/alerting/alerting.go @@ -3,8 +3,10 @@ package alerting import ( "context" "encoding/json" + "errors" "fmt" "log/slog" + "strings" "time" "github.com/jackc/pgx/v5" @@ -23,32 +25,67 @@ type Alerter struct { telemetrySender telemetry.Sender } -func (a *Alerter) registerSendersFromPool(ctx context.Context) ([]*slackAlertSender, error) { +type AlertSenderConfig struct { + Id int64 + Sender AlertSender +} + +func (a *Alerter) registerSendersFromPool(ctx context.Context) ([]AlertSenderConfig, error) { rows, err := a.catalogPool.Query(ctx, - "SELECT service_type,service_config FROM peerdb_stats.alerting_config") + "SELECT id,service_type,service_config FROM peerdb_stats.alerting_config") if err != nil { return nil, fmt.Errorf("failed to read alerter config from catalog: %w", err) } - var slackAlertSenders []*slackAlertSender - var serviceType, serviceConfig string - _, err = pgx.ForEachRow(rows, []any{&serviceType, &serviceConfig}, func() error { + var alertSenderConfigs []AlertSenderConfig + var serviceType ServiceType + var serviceConfig string + var id int64 + _, err = pgx.ForEachRow(rows, []any{&id, &serviceType, &serviceConfig}, func() error { switch serviceType { - case "slack": + case SLACK: var slackServiceConfig slackAlertConfig err = json.Unmarshal([]byte(serviceConfig), &slackServiceConfig) if err != nil { - return fmt.Errorf("failed to unmarshal Slack service config: %w", err) + return fmt.Errorf("failed to unmarshal %s service config: %w", serviceType, err) + } + + alertSenderConfigs = append(alertSenderConfigs, AlertSenderConfig{Id: id, Sender: newSlackAlertSender(&slackServiceConfig)}) + case EMAIL: + var replyToAddresses []string + if replyToEnvString := strings.TrimSpace( + peerdbenv.PeerDBAlertingEmailSenderReplyToAddresses()); replyToEnvString != "" { + replyToAddresses = strings.Split(replyToEnvString, ",") + } + emailServiceConfig := EmailAlertSenderConfig{ + sourceEmail: peerdbenv.PeerDBAlertingEmailSenderSourceEmail(), + configurationSetName: peerdbenv.PeerDBAlertingEmailSenderConfigurationSet(), + replyToAddresses: replyToAddresses, + } + if emailServiceConfig.sourceEmail == "" { + return errors.New("missing sourceEmail for Email alerting service") + } + err = json.Unmarshal([]byte(serviceConfig), &emailServiceConfig) + if err != nil { + return fmt.Errorf("failed to unmarshal %s service config: %w", serviceType, err) + } + var region *string + if envRegion := peerdbenv.PeerDBAlertingEmailSenderRegion(); envRegion != "" { + region = &envRegion } - slackAlertSenders = append(slackAlertSenders, newSlackAlertSender(&slackServiceConfig)) + alertSender, alertSenderErr := NewEmailAlertSenderWithNewClient(ctx, region, &emailServiceConfig) + if alertSenderErr != nil { + return fmt.Errorf("failed to initialize email alerter: %w", alertSenderErr) + } + alertSenderConfigs = append(alertSenderConfigs, AlertSenderConfig{Id: id, Sender: alertSender}) default: return fmt.Errorf("unknown service type: %s", serviceType) } return nil }) - return slackAlertSenders, nil + return alertSenderConfigs, nil } // doesn't take care of closing pool, needs to be done externally. @@ -75,9 +112,9 @@ func NewAlerter(ctx context.Context, catalogPool *pgxpool.Pool) *Alerter { } func (a *Alerter) AlertIfSlotLag(ctx context.Context, peerName string, slotInfo *protos.SlotInfo) { - slackAlertSenders, err := a.registerSendersFromPool(ctx) + alertSenderConfigs, err := a.registerSendersFromPool(ctx) if err != nil { - logger.LoggerFromCtx(ctx).Warn("failed to set Slack senders", slog.Any("error", err)) + logger.LoggerFromCtx(ctx).Warn("failed to set alert senders", slog.Any("error", err)) return } @@ -89,29 +126,30 @@ func (a *Alerter) AlertIfSlotLag(ctx context.Context, peerName string, slotInfo defaultSlotLagMBAlertThreshold := dynamicconf.PeerDBSlotLagMBAlertThreshold(ctx) // catalog cannot use default threshold to space alerts properly, use the lowest set threshold instead lowestSlotLagMBAlertThreshold := defaultSlotLagMBAlertThreshold - for _, slackAlertSender := range slackAlertSenders { - if slackAlertSender.slotLagMBAlertThreshold > 0 { - lowestSlotLagMBAlertThreshold = min(lowestSlotLagMBAlertThreshold, slackAlertSender.slotLagMBAlertThreshold) + for _, alertSender := range alertSenderConfigs { + if alertSender.Sender.getSlotLagMBAlertThreshold() > 0 { + lowestSlotLagMBAlertThreshold = min(lowestSlotLagMBAlertThreshold, alertSender.Sender.getSlotLagMBAlertThreshold()) } } - alertKey := peerName + "-slot-lag-threshold-exceeded" + alertKey := fmt.Sprintf("%s Slot Lag Threshold Exceeded for Peer %s", deploymentUIDPrefix, peerName) alertMessageTemplate := fmt.Sprintf("%sSlot `%s` on peer `%s` has exceeded threshold size of %%dMB, "+ - `currently at %.2fMB! - cc: `, deploymentUIDPrefix, slotInfo.SlotName, peerName, slotInfo.LagInMb) - - if slotInfo.LagInMb > float32(lowestSlotLagMBAlertThreshold) && - a.checkAndAddAlertToCatalog(ctx, alertKey, fmt.Sprintf(alertMessageTemplate, lowestSlotLagMBAlertThreshold)) { - for _, slackAlertSender := range slackAlertSenders { - if slackAlertSender.slotLagMBAlertThreshold > 0 { - if slotInfo.LagInMb > float32(slackAlertSender.slotLagMBAlertThreshold) { - a.alertToSlack(ctx, slackAlertSender, alertKey, - fmt.Sprintf(alertMessageTemplate, slackAlertSender.slotLagMBAlertThreshold)) - } - } else { - if slotInfo.LagInMb > float32(defaultSlotLagMBAlertThreshold) { - a.alertToSlack(ctx, slackAlertSender, alertKey, - fmt.Sprintf(alertMessageTemplate, defaultSlotLagMBAlertThreshold)) + `currently at %.2fMB!`, deploymentUIDPrefix, slotInfo.SlotName, peerName, slotInfo.LagInMb) + + if slotInfo.LagInMb > float32(lowestSlotLagMBAlertThreshold) { + for _, alertSenderConfig := range alertSenderConfigs { + if a.checkAndAddAlertToCatalog(ctx, + alertSenderConfig.Id, alertKey, fmt.Sprintf(alertMessageTemplate, lowestSlotLagMBAlertThreshold)) { + if alertSenderConfig.Sender.getSlotLagMBAlertThreshold() > 0 { + if slotInfo.LagInMb > float32(alertSenderConfig.Sender.getSlotLagMBAlertThreshold()) { + a.alertToProvider(ctx, alertSenderConfig, alertKey, + fmt.Sprintf(alertMessageTemplate, alertSenderConfig.Sender.getSlotLagMBAlertThreshold())) + } + } else { + if slotInfo.LagInMb > float32(defaultSlotLagMBAlertThreshold) { + a.alertToProvider(ctx, alertSenderConfig, alertKey, + fmt.Sprintf(alertMessageTemplate, defaultSlotLagMBAlertThreshold)) + } } } } @@ -121,7 +159,7 @@ func (a *Alerter) AlertIfSlotLag(ctx context.Context, peerName string, slotInfo func (a *Alerter) AlertIfOpenConnections(ctx context.Context, peerName string, openConnections *protos.GetOpenConnectionsForUserResult, ) { - slackAlertSenders, err := a.registerSendersFromPool(ctx) + alertSenderConfigs, err := a.registerSendersFromPool(ctx) if err != nil { logger.LoggerFromCtx(ctx).Warn("failed to set Slack senders", slog.Any("error", err)) return @@ -129,44 +167,45 @@ func (a *Alerter) AlertIfOpenConnections(ctx context.Context, peerName string, deploymentUIDPrefix := "" if peerdbenv.PeerDBDeploymentUID() != "" { - deploymentUIDPrefix = fmt.Sprintf("[%s] ", peerdbenv.PeerDBDeploymentUID()) + deploymentUIDPrefix = fmt.Sprintf("[%s] - ", peerdbenv.PeerDBDeploymentUID()) } // same as with slot lag, use lowest threshold for catalog defaultOpenConnectionsThreshold := dynamicconf.PeerDBOpenConnectionsAlertThreshold(ctx) lowestOpenConnectionsThreshold := defaultOpenConnectionsThreshold - for _, slackAlertSender := range slackAlertSenders { - if slackAlertSender.openConnectionsAlertThreshold > 0 { - lowestOpenConnectionsThreshold = min(lowestOpenConnectionsThreshold, slackAlertSender.openConnectionsAlertThreshold) + for _, alertSender := range alertSenderConfigs { + if alertSender.Sender.getOpenConnectionsAlertThreshold() > 0 { + lowestOpenConnectionsThreshold = min(lowestOpenConnectionsThreshold, alertSender.Sender.getOpenConnectionsAlertThreshold()) } } - alertKey := peerName + "-max-open-connections-threshold-exceeded" + alertKey := fmt.Sprintf("%s Max Open Connections Threshold Exceeded for Peer %s", deploymentUIDPrefix, peerName) alertMessageTemplate := fmt.Sprintf("%sOpen connections from PeerDB user `%s` on peer `%s`"+ - ` has exceeded threshold size of %%d connections, currently at %d connections! - cc: `, deploymentUIDPrefix, openConnections.UserName, peerName, openConnections.CurrentOpenConnections) - - if openConnections.CurrentOpenConnections > int64(lowestOpenConnectionsThreshold) && - a.checkAndAddAlertToCatalog(ctx, alertKey, fmt.Sprintf(alertMessageTemplate, lowestOpenConnectionsThreshold)) { - for _, slackAlertSender := range slackAlertSenders { - if slackAlertSender.openConnectionsAlertThreshold > 0 { - if openConnections.CurrentOpenConnections > int64(slackAlertSender.openConnectionsAlertThreshold) { - a.alertToSlack(ctx, slackAlertSender, alertKey, - fmt.Sprintf(alertMessageTemplate, slackAlertSender.openConnectionsAlertThreshold)) - } - } else { - if openConnections.CurrentOpenConnections > int64(defaultOpenConnectionsThreshold) { - a.alertToSlack(ctx, slackAlertSender, alertKey, - fmt.Sprintf(alertMessageTemplate, defaultOpenConnectionsThreshold)) + ` has exceeded threshold size of %%d connections, currently at %d connections!`, + deploymentUIDPrefix, openConnections.UserName, peerName, openConnections.CurrentOpenConnections) + + if openConnections.CurrentOpenConnections > int64(lowestOpenConnectionsThreshold) { + for _, alertSenderConfig := range alertSenderConfigs { + if a.checkAndAddAlertToCatalog(ctx, + alertSenderConfig.Id, alertKey, fmt.Sprintf(alertMessageTemplate, lowestOpenConnectionsThreshold)) { + if alertSenderConfig.Sender.getOpenConnectionsAlertThreshold() > 0 { + if openConnections.CurrentOpenConnections > int64(alertSenderConfig.Sender.getOpenConnectionsAlertThreshold()) { + a.alertToProvider(ctx, alertSenderConfig, alertKey, + fmt.Sprintf(alertMessageTemplate, alertSenderConfig.Sender.getOpenConnectionsAlertThreshold())) + } + } else { + if openConnections.CurrentOpenConnections > int64(defaultOpenConnectionsThreshold) { + a.alertToProvider(ctx, alertSenderConfig, alertKey, + fmt.Sprintf(alertMessageTemplate, defaultOpenConnectionsThreshold)) + } } } } } } -func (a *Alerter) alertToSlack(ctx context.Context, slackAlertSender *slackAlertSender, alertKey string, alertMessage string) { - err := slackAlertSender.sendAlert(ctx, - ":rotating_light:Alert:rotating_light:: "+alertKey, alertMessage) +func (a *Alerter) alertToProvider(ctx context.Context, alertSenderConfig AlertSenderConfig, alertKey string, alertMessage string) { + err := alertSenderConfig.Sender.sendAlert(ctx, alertKey, alertMessage) if err != nil { logger.LoggerFromCtx(ctx).Warn("failed to send alert", slog.Any("error", err)) return @@ -176,7 +215,7 @@ func (a *Alerter) alertToSlack(ctx context.Context, slackAlertSender *slackAlert // Only raises an alert if another alert with the same key hasn't been raised // in the past X minutes, where X is configurable and defaults to 15 minutes // returns true if alert added to catalog, so proceed with processing alerts to slack -func (a *Alerter) checkAndAddAlertToCatalog(ctx context.Context, alertKey string, alertMessage string) bool { +func (a *Alerter) checkAndAddAlertToCatalog(ctx context.Context, alertConfigId int64, alertKey string, alertMessage string) bool { dur := dynamicconf.PeerDBAlertingGapMinutesAsDuration(ctx) if dur == 0 { logger.LoggerFromCtx(ctx).Warn("Alerting disabled via environment variable, returning") @@ -184,9 +223,9 @@ func (a *Alerter) checkAndAddAlertToCatalog(ctx context.Context, alertKey string } row := a.catalogPool.QueryRow(ctx, - `SELECT created_timestamp FROM peerdb_stats.alerts_v1 WHERE alert_key=$1 + `SELECT created_timestamp FROM peerdb_stats.alerts_v1 WHERE alert_key=$1 AND alert_config_id=$2 ORDER BY created_timestamp DESC LIMIT 1`, - alertKey) + alertKey, alertConfigId) var createdTimestamp time.Time err := row.Scan(&createdTimestamp) if err != nil && err != pgx.ErrNoRows { @@ -196,14 +235,18 @@ func (a *Alerter) checkAndAddAlertToCatalog(ctx context.Context, alertKey string if time.Since(createdTimestamp) >= dur { _, err = a.catalogPool.Exec(ctx, - "INSERT INTO peerdb_stats.alerts_v1(alert_key,alert_message) VALUES($1,$2)", - alertKey, alertMessage) + "INSERT INTO peerdb_stats.alerts_v1(alert_key,alert_message,alert_config_id) VALUES($1,$2,$3)", + alertKey, alertMessage, alertConfigId) if err != nil { logger.LoggerFromCtx(ctx).Warn("failed to insert alert", slog.Any("error", err)) return false } return true } + + logger.LoggerFromCtx(ctx).Info( + fmt.Sprintf("Skipped sending alerts: last alert was sent at %s, which was >=%s ago", + createdTimestamp.String(), dur.String())) return false } diff --git a/flow/alerting/email_alert_sender.go b/flow/alerting/email_alert_sender.go new file mode 100644 index 0000000000..2a534318c0 --- /dev/null +++ b/flow/alerting/email_alert_sender.go @@ -0,0 +1,104 @@ +package alerting + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ses" + "github.com/aws/aws-sdk-go-v2/service/ses/types" + + "github.com/PeerDB-io/peer-flow/logger" + "github.com/PeerDB-io/peer-flow/peerdbenv" + "github.com/PeerDB-io/peer-flow/shared/aws_common" +) + +type EmailAlertSender struct { + AlertSender + client *ses.Client + sourceEmail string + configurationSetName string + replyToAddresses []string + slotLagMBAlertThreshold uint32 + openConnectionsAlertThreshold uint32 + emailAddresses []string +} + +func (e *EmailAlertSender) getSlotLagMBAlertThreshold() uint32 { + return e.slotLagMBAlertThreshold +} + +func (e *EmailAlertSender) getOpenConnectionsAlertThreshold() uint32 { + return e.openConnectionsAlertThreshold +} + +type EmailAlertSenderConfig struct { + sourceEmail string + configurationSetName string + replyToAddresses []string + SlotLagMBAlertThreshold uint32 `json:"slot_lag_mb_alert_threshold"` + OpenConnectionsAlertThreshold uint32 `json:"open_connections_alert_threshold"` + EmailAddresses []string `json:"email_addresses"` +} + +func (e *EmailAlertSender) sendAlert(ctx context.Context, alertTitle string, alertMessage string) error { + _, err := e.client.SendEmail(ctx, &ses.SendEmailInput{ + Destination: &types.Destination{ + ToAddresses: e.emailAddresses, + }, + Message: &types.Message{ + Body: &types.Body{ + Text: &types.Content{ + Data: aws.String(alertMessage), + Charset: aws.String("utf-8"), + }, + }, + Subject: &types.Content{ + Data: aws.String(alertTitle), + Charset: aws.String("utf-8"), + }, + }, + Source: aws.String(e.sourceEmail), + ConfigurationSetName: aws.String(e.configurationSetName), + ReplyToAddresses: e.replyToAddresses, + Tags: []types.MessageTag{ + {Name: aws.String("DeploymentUUID"), Value: aws.String(peerdbenv.PeerDBDeploymentUID())}, + }, + }) + if err != nil { + logger.LoggerFromCtx(ctx).Warn(fmt.Sprintf( + "Error sending email alert from %v to %s subject=[%s], body=[%s], configurationSet=%s, replyToAddresses=[%v]", + e.sourceEmail, e.emailAddresses, alertTitle, alertMessage, e.configurationSetName, e.replyToAddresses)) + return err + } + return nil +} + +func NewEmailAlertSenderWithNewClient(ctx context.Context, region *string, config *EmailAlertSenderConfig) (*EmailAlertSender, error) { + client, err := newSesClient(ctx, region) + if err != nil { + return nil, err + } + return NewEmailAlertSender(client, config), nil +} + +func NewEmailAlertSender(client *ses.Client, config *EmailAlertSenderConfig) *EmailAlertSender { + return &EmailAlertSender{ + client: client, + sourceEmail: config.sourceEmail, + configurationSetName: config.configurationSetName, + replyToAddresses: config.replyToAddresses, + slotLagMBAlertThreshold: config.SlotLagMBAlertThreshold, + openConnectionsAlertThreshold: config.OpenConnectionsAlertThreshold, + emailAddresses: config.EmailAddresses, + } +} + +func newSesClient(ctx context.Context, region *string) (*ses.Client, error) { + sdkConfig, err := aws_common.LoadSdkConfig(ctx, region) + if err != nil { + return nil, err + } + snsClient := ses.NewFromConfig(*sdkConfig) + return snsClient, nil +} diff --git a/flow/alerting/interface.go b/flow/alerting/interface.go new file mode 100644 index 0000000000..a9dd2d51b9 --- /dev/null +++ b/flow/alerting/interface.go @@ -0,0 +1,9 @@ +package alerting + +import "context" + +type AlertSender interface { + sendAlert(ctx context.Context, alertTitle string, alertMessage string) error + getSlotLagMBAlertThreshold() uint32 + getOpenConnectionsAlertThreshold() uint32 +} diff --git a/flow/alerting/slack_alert_sender.go b/flow/alerting/slack_alert_sender.go index 1ad007536b..85f0657d11 100644 --- a/flow/alerting/slack_alert_sender.go +++ b/flow/alerting/slack_alert_sender.go @@ -7,13 +7,22 @@ import ( "github.com/slack-go/slack" ) -type slackAlertSender struct { +type SlackAlertSender struct { + AlertSender client *slack.Client channelIDs []string slotLagMBAlertThreshold uint32 openConnectionsAlertThreshold uint32 } +func (s *SlackAlertSender) getSlotLagMBAlertThreshold() uint32 { + return s.slotLagMBAlertThreshold +} + +func (s *SlackAlertSender) getOpenConnectionsAlertThreshold() uint32 { + return s.openConnectionsAlertThreshold +} + type slackAlertConfig struct { AuthToken string `json:"auth_token"` ChannelIDs []string `json:"channel_ids"` @@ -21,8 +30,8 @@ type slackAlertConfig struct { OpenConnectionsAlertThreshold uint32 `json:"open_connections_alert_threshold"` } -func newSlackAlertSender(config *slackAlertConfig) *slackAlertSender { - return &slackAlertSender{ +func newSlackAlertSender(config *slackAlertConfig) *SlackAlertSender { + return &SlackAlertSender{ client: slack.New(config.AuthToken), channelIDs: config.ChannelIDs, slotLagMBAlertThreshold: config.SlotLagMBAlertThreshold, @@ -30,11 +39,11 @@ func newSlackAlertSender(config *slackAlertConfig) *slackAlertSender { } } -func (s *slackAlertSender) sendAlert(ctx context.Context, alertTitle string, alertMessage string) error { +func (s *SlackAlertSender) sendAlert(ctx context.Context, alertTitle string, alertMessage string) error { for _, channelID := range s.channelIDs { _, _, _, err := s.client.SendMessageContext(ctx, channelID, slack.MsgOptionBlocks( - slack.NewHeaderBlock(slack.NewTextBlockObject("plain_text", alertTitle, true, false)), - slack.NewSectionBlock(slack.NewTextBlockObject("mrkdwn", alertMessage, false, false), nil, nil), + slack.NewHeaderBlock(slack.NewTextBlockObject("plain_text", ":rotating_light:Alert:rotating_light:: "+alertTitle, true, false)), + slack.NewSectionBlock(slack.NewTextBlockObject("mrkdwn", alertMessage+"\ncc: ", false, false), nil, nil), )) if err != nil { return fmt.Errorf("failed to send message to Slack channel %s: %w", channelID, err) diff --git a/flow/alerting/types.go b/flow/alerting/types.go new file mode 100644 index 0000000000..6277010ee0 --- /dev/null +++ b/flow/alerting/types.go @@ -0,0 +1,8 @@ +package alerting + +type ServiceType string + +const ( + SLACK ServiceType = "slack" + EMAIL ServiceType = "email" +) diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 9c86baef50..31ab25c27f 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -28,11 +28,16 @@ type Connector interface { ConnectionActive(context.Context) error } -type CDCPullConnector interface { +type GetTableSchemaConnector interface { Connector // GetTableSchema returns the schema of a table. GetTableSchema(ctx context.Context, req *protos.GetTableSchemaBatchInput) (*protos.GetTableSchemaBatchOutput, error) +} + +type CDCPullConnector interface { + Connector + GetTableSchemaConnector // EnsurePullability ensures that the connector is pullable. EnsurePullability(ctx context.Context, req *protos.EnsurePullabilityBatchInput) ( @@ -55,6 +60,9 @@ type CDCPullConnector interface { // This method should be idempotent, and should be able to be called multiple times with the same request. PullRecords(ctx context.Context, catalogPool *pgxpool.Pool, req *model.PullRecordsRequest) error + // Called when offset has been confirmed to destination + UpdateReplStateLastOffset(lastOffset int64) + // PullFlowCleanup drops both the Postgres publication and replication slot, as a part of DROP MIRROR PullFlowCleanup(ctx context.Context, jobName string) error @@ -256,6 +264,9 @@ var ( _ CDCNormalizeConnector = &connsnowflake.SnowflakeConnector{} _ CDCNormalizeConnector = &connclickhouse.ClickhouseConnector{} + _ GetTableSchemaConnector = &connpostgres.PostgresConnector{} + _ GetTableSchemaConnector = &connsnowflake.SnowflakeConnector{} + _ NormalizedTablesConnector = &connpostgres.PostgresConnector{} _ NormalizedTablesConnector = &connbigquery.BigQueryConnector{} _ NormalizedTablesConnector = &connsnowflake.SnowflakeConnector{} diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 22ac244aa9..898f962aec 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -458,7 +458,7 @@ func generateCreateTableSQLForNormalizedTable( } // add composite primary key to the table - if len(sourceTableSchema.PrimaryKeyColumns) > 0 { + if len(sourceTableSchema.PrimaryKeyColumns) > 0 && !sourceTableSchema.IsReplicaIdentityFull { primaryKeyColsQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns)) for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns { primaryKeyColsQuoted = append(primaryKeyColsQuoted, QuoteIdentifier(primaryKeyCol)) diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 65b3be2512..3ea75c17d0 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -8,6 +8,7 @@ import ( "regexp" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -49,6 +50,7 @@ type ReplState struct { Slot string Publication string Offset int64 + LastOffset atomic.Int64 } func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) (*PostgresConnector, error) { @@ -133,7 +135,7 @@ func (c *PostgresConnector) ReplPing(ctx context.Context) error { return pglogrepl.SendStandbyStatusUpdate( ctx, c.replConn.PgConn(), - pglogrepl.StandbyStatusUpdate{WALWritePosition: pglogrepl.LSN(c.replState.Offset)}, + pglogrepl.StandbyStatusUpdate{WALWritePosition: pglogrepl.LSN(c.replState.LastOffset.Load())}, ) } } @@ -184,7 +186,9 @@ func (c *PostgresConnector) MaybeStartReplication( Slot: slotName, Publication: publicationName, Offset: req.LastOffset, + LastOffset: atomic.Int64{}, } + c.replState.LastOffset.Store(req.LastOffset) } return nil } @@ -308,6 +312,9 @@ func (c *PostgresConnector) SetLastOffset(ctx context.Context, jobName string, l func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpool.Pool, req *model.PullRecordsRequest) error { defer func() { req.RecordStream.Close() + if c.replState != nil { + c.replState.Offset = req.RecordStream.GetLastCheckpoint() + } }() // Slotname would be the job name prefixed with "peerflow_slot_" @@ -371,9 +378,6 @@ func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpoo return err } - req.RecordStream.Close() - c.replState.Offset = req.RecordStream.GetLastCheckpoint() - latestLSN, err := c.getCurrentLSN(ctx) if err != nil { c.logger.Error("error getting current LSN", slog.Any("error", err)) @@ -389,6 +393,12 @@ func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpoo return nil } +func (c *PostgresConnector) UpdateReplStateLastOffset(lastOffset int64) { + if c.replState != nil { + c.replState.LastOffset.Store(lastOffset) + } +} + // SyncRecords pushes records to the destination. func (c *PostgresConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) { rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) diff --git a/flow/connectors/snowflake/get_schema_for_tests.go b/flow/connectors/snowflake/get_schema_for_tests.go index 05631e635f..476f16f165 100644 --- a/flow/connectors/snowflake/get_schema_for_tests.go +++ b/flow/connectors/snowflake/get_schema_for_tests.go @@ -3,7 +3,6 @@ package connsnowflake import ( "context" - "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model/qvalue" ) @@ -47,7 +46,6 @@ func (c *SnowflakeConnector) GetTableSchema( return nil, err } res[tableName] = tableSchema - utils.RecordHeartbeat(ctx, "fetched schema for table "+tableName) } return &protos.GetTableSchemaBatchOutput{ diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 775c130665..4ba626a053 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -727,7 +727,7 @@ func generateCreateTableSQLForNormalizedTable( } // add composite primary key to the table - if len(sourceTableSchema.PrimaryKeyColumns) > 0 { + if len(sourceTableSchema.PrimaryKeyColumns) > 0 && !sourceTableSchema.IsReplicaIdentityFull { normalizedPrimaryKeyCols := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns)) for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns { normalizedPrimaryKeyCols = append(normalizedPrimaryKeyCols, diff --git a/flow/e2e/bigquery/bigquery.go b/flow/e2e/bigquery/bigquery.go index 73f5c38d6e..1e2a3842ee 100644 --- a/flow/e2e/bigquery/bigquery.go +++ b/flow/e2e/bigquery/bigquery.go @@ -8,6 +8,7 @@ import ( "github.com/jackc/pgx/v5" + "github.com/PeerDB-io/peer-flow/connectors" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/generated/protos" @@ -35,6 +36,11 @@ func (s PeerFlowE2ETestSuiteBQ) Connector() *connpostgres.PostgresConnector { return s.conn } +func (s PeerFlowE2ETestSuiteBQ) DestinationConnector() connectors.Connector { + // TODO have BQ connector + return nil +} + func (s PeerFlowE2ETestSuiteBQ) Suffix() string { return s.bqSuffix } diff --git a/flow/e2e/bigquery/peer_flow_bq_test.go b/flow/e2e/bigquery/peer_flow_bq_test.go index ec28b5f97b..3dff6e310e 100644 --- a/flow/e2e/bigquery/peer_flow_bq_test.go +++ b/flow/e2e/bigquery/peer_flow_bq_test.go @@ -643,7 +643,8 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Multi_Table_BQ() { e2e.RequireEnvCanceled(s.t, env) } -// TODO: not checking schema exactly, add later +// TODO: not checking schema exactly +// write a GetTableSchemaConnector for BQ to enable generic_test func (s PeerFlowE2ETestSuiteBQ) Test_Simple_Schema_Changes_BQ() { tc := e2e.NewTemporalClient(s.t) diff --git a/flow/e2e/generic/generic_test.go b/flow/e2e/generic/generic_test.go new file mode 100644 index 0000000000..97e64f5ed3 --- /dev/null +++ b/flow/e2e/generic/generic_test.go @@ -0,0 +1,304 @@ +package e2e_generic + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/PeerDB-io/peer-flow/connectors" + "github.com/PeerDB-io/peer-flow/e2e" + "github.com/PeerDB-io/peer-flow/e2e/bigquery" + "github.com/PeerDB-io/peer-flow/e2e/postgres" + "github.com/PeerDB-io/peer-flow/e2e/snowflake" + "github.com/PeerDB-io/peer-flow/e2eshared" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model/qvalue" + peerflow "github.com/PeerDB-io/peer-flow/workflows" +) + +func TestGenericPG(t *testing.T) { + e2eshared.RunSuite(t, SetupGenericSuite(e2e_postgres.SetupSuite)) +} + +func TestGenericSF(t *testing.T) { + e2eshared.RunSuite(t, SetupGenericSuite(e2e_snowflake.SetupSuite)) +} + +func TestGenericBQ(t *testing.T) { + e2eshared.RunSuite(t, SetupGenericSuite(e2e_bigquery.SetupSuite)) +} + +type Generic struct { + e2e.GenericSuite +} + +func SetupGenericSuite[T e2e.GenericSuite](f func(t *testing.T) T) func(t *testing.T) Generic { + return func(t *testing.T) Generic { + t.Helper() + return Generic{f(t)} + } +} + +func (s Generic) Test_Simple_Flow() { + t := s.T() + srcTable := "test_simple" + dstTable := "test_simple_dst" + srcSchemaTable := e2e.AttachSchema(s, srcTable) + + _, err := s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id SERIAL PRIMARY KEY, + key TEXT NOT NULL, + value TEXT NOT NULL, + myh HSTORE NOT NULL + ); + `, srcSchemaTable)) + require.NoError(t, err) + + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: e2e.AddSuffix(s, "test_simple"), + TableMappings: e2e.TableMappings(s, srcTable, dstTable), + Destination: s.Peer(), + } + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() + + tc := e2e.NewTemporalClient(t) + env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) + + e2e.SetupCDCFlowStatusQuery(t, env, connectionGen) + // insert 10 rows into the source table + for i := range 10 { + testKey := fmt.Sprintf("test_key_%d", i) + testValue := fmt.Sprintf("test_value_%d", i) + _, err = s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(key, value, myh) VALUES ($1, $2, '"a"=>"b"') + `, srcSchemaTable), testKey, testValue) + e2e.EnvNoError(t, env, err) + } + t.Log("Inserted 10 rows into the source table") + + e2e.EnvWaitForEqualTablesWithNames(env, s, "normalizing 10 rows", srcTable, dstTable, `id,key,value,myh`) + env.Cancel() + e2e.RequireEnvCanceled(t, env) +} + +func (s Generic) Test_Simple_Schema_Changes() { + t := s.T() + + destinationSchemaConnector, ok := s.DestinationConnector().(connectors.GetTableSchemaConnector) + if !ok { + t.SkipNow() + } + + srcTable := "test_simple_schema_changes" + dstTable := "test_simple_schema_changes_dst" + srcTableName := e2e.AttachSchema(s, srcTable) + dstTableName := s.DestinationTable(dstTable) + + _, err := s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + c1 BIGINT + ); + `, srcTableName)) + require.NoError(t, err) + + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: e2e.AddSuffix(s, srcTable), + TableMappings: e2e.TableMappings(s, srcTable, dstTable), + Destination: s.Peer(), + } + + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() + + // wait for PeerFlowStatusQuery to finish setup + // and then insert and mutate schema repeatedly. + tc := e2e.NewTemporalClient(t) + env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) + e2e.SetupCDCFlowStatusQuery(t, env, connectionGen) + _, err = s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 1) + e2e.EnvNoError(t, env, err) + t.Log("Inserted initial row in the source table") + + e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize reinsert", srcTable, dstTable, "id,c1") + + expectedTableSchema := &protos.TableSchema{ + TableIdentifier: e2e.ExpectedDestinationTableName(s, dstTable), + Columns: []*protos.FieldDescription{ + { + Name: e2e.ExpectedDestinationIdentifier(s, "id"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: e2e.ExpectedDestinationIdentifier(s, "c1"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: "_PEERDB_IS_DELETED", + Type: string(qvalue.QValueKindBoolean), + TypeModifier: -1, + }, + { + Name: "_PEERDB_SYNCED_AT", + Type: string(qvalue.QValueKindTimestamp), + TypeModifier: -1, + }, + }, + } + output, err := destinationSchemaConnector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ + TableIdentifiers: []string{dstTableName}, + }) + e2e.EnvNoError(t, env, err) + e2e.EnvTrue(t, env, e2e.CompareTableSchemas(expectedTableSchema, output.TableNameSchemaMapping[dstTableName])) + + // alter source table, add column c2 and insert another row. + _, err = s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + ALTER TABLE %s ADD COLUMN c2 BIGINT`, srcTableName)) + e2e.EnvNoError(t, env, err) + t.Log("Altered source table, added column c2") + _, err = s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,c2) VALUES ($1,$2)`, srcTableName), 2, 2) + e2e.EnvNoError(t, env, err) + t.Log("Inserted row with added c2 in the source table") + + // verify we got our two rows, if schema did not match up it will error. + e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize altered row", srcTable, dstTable, "id,c1,c2") + expectedTableSchema = &protos.TableSchema{ + TableIdentifier: e2e.ExpectedDestinationTableName(s, dstTable), + Columns: []*protos.FieldDescription{ + { + Name: e2e.ExpectedDestinationIdentifier(s, "id"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: e2e.ExpectedDestinationIdentifier(s, "c1"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: "_PEERDB_SYNCED_AT", + Type: string(qvalue.QValueKindTimestamp), + TypeModifier: -1, + }, + { + Name: e2e.ExpectedDestinationIdentifier(s, "c2"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + }, + } + output, err = destinationSchemaConnector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ + TableIdentifiers: []string{dstTableName}, + }) + e2e.EnvNoError(t, env, err) + e2e.EnvTrue(t, env, e2e.CompareTableSchemas(expectedTableSchema, output.TableNameSchemaMapping[dstTableName])) + e2e.EnvEqualTablesWithNames(env, s, srcTable, dstTable, "id,c1,c2") + + // alter source table, add column c3, drop column c2 and insert another row. + _, err = s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + ALTER TABLE %s DROP COLUMN c2, ADD COLUMN c3 BIGINT`, srcTableName)) + e2e.EnvNoError(t, env, err) + t.Log("Altered source table, dropped column c2 and added column c3") + _, err = s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,c3) VALUES ($1,$2)`, srcTableName), 3, 3) + e2e.EnvNoError(t, env, err) + t.Log("Inserted row with added c3 in the source table") + + // verify we got our two rows, if schema did not match up it will error. + e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize dropped c2 column", srcTable, dstTable, "id,c1,c3") + expectedTableSchema = &protos.TableSchema{ + TableIdentifier: e2e.ExpectedDestinationTableName(s, dstTable), + Columns: []*protos.FieldDescription{ + { + Name: e2e.ExpectedDestinationIdentifier(s, "id"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: e2e.ExpectedDestinationIdentifier(s, "c1"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: "_PEERDB_SYNCED_AT", + Type: string(qvalue.QValueKindTimestamp), + TypeModifier: -1, + }, + { + Name: e2e.ExpectedDestinationIdentifier(s, "c2"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: e2e.ExpectedDestinationIdentifier(s, "c3"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + }, + } + output, err = destinationSchemaConnector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ + TableIdentifiers: []string{dstTableName}, + }) + e2e.EnvNoError(t, env, err) + e2e.EnvTrue(t, env, e2e.CompareTableSchemas(expectedTableSchema, output.TableNameSchemaMapping[dstTableName])) + e2e.EnvEqualTablesWithNames(env, s, srcTable, dstTable, "id,c1,c3") + + // alter source table, drop column c3 and insert another row. + _, err = s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + ALTER TABLE %s DROP COLUMN c3`, srcTableName)) + e2e.EnvNoError(t, env, err) + t.Log("Altered source table, dropped column c3") + _, err = s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 4) + e2e.EnvNoError(t, env, err) + t.Log("Inserted row after dropping all columns in the source table") + + // verify we got our two rows, if schema did not match up it will error. + e2e.EnvWaitForEqualTablesWithNames(env, s, "normalize dropped c3 column", srcTable, dstTable, "id,c1") + expectedTableSchema = &protos.TableSchema{ + TableIdentifier: e2e.ExpectedDestinationTableName(s, dstTable), + Columns: []*protos.FieldDescription{ + { + Name: e2e.ExpectedDestinationIdentifier(s, "id"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: e2e.ExpectedDestinationIdentifier(s, "c1"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: "_PEERDB_SYNCED_AT", + Type: string(qvalue.QValueKindTimestamp), + TypeModifier: -1, + }, + { + Name: e2e.ExpectedDestinationIdentifier(s, "c2"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + { + Name: e2e.ExpectedDestinationIdentifier(s, "c3"), + Type: string(qvalue.QValueKindNumeric), + TypeModifier: -1, + }, + }, + } + output, err = destinationSchemaConnector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ + TableIdentifiers: []string{dstTableName}, + }) + e2e.EnvNoError(t, env, err) + e2e.EnvTrue(t, env, e2e.CompareTableSchemas(expectedTableSchema, output.TableNameSchemaMapping[dstTableName])) + e2e.EnvEqualTablesWithNames(env, s, srcTable, dstTable, "id,c1") + + env.Cancel() + + e2e.RequireEnvCanceled(t, env) +} diff --git a/flow/e2e/generic/peer_flow_test.go b/flow/e2e/generic/peer_flow_test.go deleted file mode 100644 index 20c5847df4..0000000000 --- a/flow/e2e/generic/peer_flow_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package e2e_generic - -import ( - "context" - "fmt" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/PeerDB-io/peer-flow/e2e" - "github.com/PeerDB-io/peer-flow/e2e/bigquery" - "github.com/PeerDB-io/peer-flow/e2e/postgres" - "github.com/PeerDB-io/peer-flow/e2e/snowflake" - "github.com/PeerDB-io/peer-flow/e2eshared" - peerflow "github.com/PeerDB-io/peer-flow/workflows" -) - -func TestGenericPG(t *testing.T) { - e2eshared.RunSuite(t, SetupGenericSuite(e2e_postgres.SetupSuite)) -} - -func TestGenericSF(t *testing.T) { - e2eshared.RunSuite(t, SetupGenericSuite(e2e_snowflake.SetupSuite)) -} - -func TestGenericBQ(t *testing.T) { - e2eshared.RunSuite(t, SetupGenericSuite(e2e_bigquery.SetupSuite)) -} - -type Generic struct { - e2e.GenericSuite -} - -func SetupGenericSuite[T e2e.GenericSuite](f func(t *testing.T) T) func(t *testing.T) Generic { - return func(t *testing.T) Generic { - t.Helper() - return Generic{f(t)} - } -} - -func (s Generic) Test_Simple_Flow() { - t := s.T() - srcTable := "test_simple" - dstTable := "test_simple_dst" - srcSchemaTable := e2e.AttachSchema(s, srcTable) - - _, err := s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s ( - id SERIAL PRIMARY KEY, - key TEXT NOT NULL, - value TEXT NOT NULL, - myh HSTORE NOT NULL - ); - `, srcSchemaTable)) - require.NoError(t, err) - - connectionGen := e2e.FlowConnectionGenerationConfig{ - FlowJobName: e2e.AddSuffix(s, "test_simple"), - TableMappings: e2e.TableMappings(s, srcTable, dstTable), - Destination: s.Peer(), - } - flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() - - tc := e2e.NewTemporalClient(t) - env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) - - e2e.SetupCDCFlowStatusQuery(t, env, connectionGen) - // insert 10 rows into the source table - for i := range 10 { - testKey := fmt.Sprintf("test_key_%d", i) - testValue := fmt.Sprintf("test_value_%d", i) - _, err = s.Connector().Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(key, value, myh) VALUES ($1, $2, '"a"=>"b"') - `, srcSchemaTable), testKey, testValue) - e2e.EnvNoError(t, env, err) - } - t.Log("Inserted 10 rows into the source table") - - e2e.EnvWaitForEqualTablesWithNames(env, s, "normalizing 10 rows", srcTable, dstTable, `id,key,value,myh`) - env.Cancel() - e2e.RequireEnvCanceled(t, env) -} diff --git a/flow/e2e/postgres/peer_flow_pg_test.go b/flow/e2e/postgres/peer_flow_pg_test.go index 2e69376b01..eeec5e373d 100644 --- a/flow/e2e/postgres/peer_flow_pg_test.go +++ b/flow/e2e/postgres/peer_flow_pg_test.go @@ -17,7 +17,6 @@ import ( "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" - "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/shared" peerflow "github.com/PeerDB-io/peer-flow/workflows" ) @@ -51,32 +50,6 @@ func (s PeerFlowE2ETestSuitePG) checkPeerdbColumns(dstSchemaQualified string, ro return nil } -func (s PeerFlowE2ETestSuitePG) WaitForSchema( - env e2e.WorkflowRun, - reason string, - srcTableName string, - dstTableName string, - cols string, - expectedSchema *protos.TableSchema, -) { - s.t.Helper() - e2e.EnvWaitFor(s.t, env, 3*time.Minute, reason, func() bool { - s.t.Helper() - output, err := s.conn.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ - TableIdentifiers: []string{dstTableName}, - }) - if err != nil { - return false - } - tableSchema := output.TableNameSchemaMapping[dstTableName] - if !e2e.CompareTableSchemas(expectedSchema, tableSchema) { - s.t.Log("schemas unequal", expectedSchema, tableSchema) - return false - } - return s.comparePGTables(srcTableName, dstTableName, cols) == nil - }) -} - func (s PeerFlowE2ETestSuitePG) Test_Geospatial_PG() { srcTableName := s.attachSchemaSuffix("test_geospatial_pg") dstTableName := s.attachSchemaSuffix("test_geospatial_pg_dst") @@ -224,188 +197,6 @@ func (s PeerFlowE2ETestSuitePG) Test_Enums_PG() { e2e.RequireEnvCanceled(s.t, env) } -func (s PeerFlowE2ETestSuitePG) Test_Simple_Schema_Changes_PG() { - tc := e2e.NewTemporalClient(s.t) - - srcTableName := s.attachSchemaSuffix("test_simple_schema_changes") - dstTableName := s.attachSchemaSuffix("test_simple_schema_changes_dst") - - _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s ( - id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - c1 BIGINT - ); - `, srcTableName)) - require.NoError(s.t, err) - - connectionGen := e2e.FlowConnectionGenerationConfig{ - FlowJobName: s.attachSuffix("test_simple_schema_changes"), - TableNameMapping: map[string]string{srcTableName: dstTableName}, - Destination: s.peer, - } - - flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() - flowConnConfig.MaxBatchSize = 1 - - // wait for PeerFlowStatusQuery to finish setup - // and then insert and mutate schema repeatedly. - env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) - e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - - // insert first row. - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 1) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Inserted initial row in the source table") - - s.WaitForSchema(env, "normalizing first row", srcTableName, dstTableName, "id,c1", &protos.TableSchema{ - TableIdentifier: dstTableName, - PrimaryKeyColumns: []string{"id"}, - Columns: []*protos.FieldDescription{ - { - Name: "id", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "c1", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "_PEERDB_SYNCED_AT", - Type: string(qvalue.QValueKindTimestamp), - TypeModifier: -1, - }, - }, - }) - - // alter source table, add column c2 and insert another row. - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - ALTER TABLE %s ADD COLUMN c2 BIGINT`, srcTableName)) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Altered source table, added column c2") - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1,c2) VALUES ($1,$2)`, srcTableName), 2, 2) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Inserted row with added c2 in the source table") - - s.WaitForSchema(env, "normalizing altered row", srcTableName, dstTableName, "id,c1,c2", &protos.TableSchema{ - TableIdentifier: dstTableName, - PrimaryKeyColumns: []string{"id"}, - Columns: []*protos.FieldDescription{ - { - Name: "id", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "c1", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "_PEERDB_SYNCED_AT", - Type: string(qvalue.QValueKindTimestamp), - TypeModifier: -1, - }, - { - Name: "c2", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - }, - }) - - // alter source table, add column c3, drop column c2 and insert another row. - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - ALTER TABLE %s DROP COLUMN c2, ADD COLUMN c3 BIGINT`, srcTableName)) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Altered source table, dropped column c2 and added column c3") - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1,c3) VALUES ($1,$2)`, srcTableName), 3, 3) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Inserted row with added c3 in the source table") - - s.WaitForSchema(env, "normalizing dropped column row", srcTableName, dstTableName, "id,c1,c3", &protos.TableSchema{ - TableIdentifier: dstTableName, - PrimaryKeyColumns: []string{"id"}, - Columns: []*protos.FieldDescription{ - { - Name: "id", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "c1", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "c2", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "_PEERDB_SYNCED_AT", - Type: string(qvalue.QValueKindTimestamp), - TypeModifier: -1, - }, - { - Name: "c3", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - }, - }) - - // alter source table, drop column c3 and insert another row. - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - ALTER TABLE %s DROP COLUMN c3`, srcTableName)) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Altered source table, dropped column c3") - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 4) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Inserted row after dropping all columns in the source table") - - s.WaitForSchema(env, "normalizing 2nd dropped column row", srcTableName, dstTableName, "id,c1", &protos.TableSchema{ - TableIdentifier: dstTableName, - PrimaryKeyColumns: []string{"id"}, - Columns: []*protos.FieldDescription{ - { - Name: "id", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "c1", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "_PEERDB_SYNCED_AT", - Type: string(qvalue.QValueKindTimestamp), - TypeModifier: -1, - }, - { - Name: "c2", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - { - Name: "c3", - Type: string(qvalue.QValueKindInt64), - TypeModifier: -1, - }, - }, - }) - - env.Cancel() - - e2e.RequireEnvCanceled(s.t, env) -} - func (s PeerFlowE2ETestSuitePG) Test_Composite_PKey_PG() { tc := e2e.NewTemporalClient(s.t) diff --git a/flow/e2e/postgres/postgres.go b/flow/e2e/postgres/postgres.go index 23ca778c8d..8eafd6ade0 100644 --- a/flow/e2e/postgres/postgres.go +++ b/flow/e2e/postgres/postgres.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" + "github.com/PeerDB-io/peer-flow/connectors" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/generated/protos" @@ -32,6 +33,10 @@ func (s PeerFlowE2ETestSuitePG) Connector() *connpostgres.PostgresConnector { return s.conn } +func (s PeerFlowE2ETestSuitePG) DestinationConnector() connectors.Connector { + return s.conn +} + func (s PeerFlowE2ETestSuitePG) Conn() *pgx.Conn { return s.conn.Conn() } diff --git a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index 525d2c7256..56084a1a27 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "strings" "testing" "time" @@ -16,7 +15,6 @@ import ( "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/e2eshared" "github.com/PeerDB-io/peer-flow/generated/protos" - "github.com/PeerDB-io/peer-flow/model/qvalue" peerflow "github.com/PeerDB-io/peer-flow/workflows" ) @@ -516,218 +514,6 @@ func (s PeerFlowE2ETestSuiteSF) Test_Multi_Table_SF() { e2e.RequireEnvCanceled(s.t, env) } -func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { - tc := e2e.NewTemporalClient(s.t) - - srcTableName := s.attachSchemaSuffix("test_simple_schema_changes") - dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, "test_simple_schema_changes") - - _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s ( - id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - c1 BIGINT - ); - `, srcTableName)) - require.NoError(s.t, err) - - connectionGen := e2e.FlowConnectionGenerationConfig{ - FlowJobName: s.attachSuffix("test_simple_schema_changes"), - TableNameMapping: map[string]string{srcTableName: dstTableName}, - Destination: s.sfHelper.Peer, - } - - flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() - flowConnConfig.MaxBatchSize = 100 - - // wait for PeerFlowStatusQuery to finish setup - // and then insert and mutate schema repeatedly. - env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) - e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 1) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Inserted initial row in the source table") - - e2e.EnvWaitForEqualTables(env, s, "normalize reinsert", "test_simple_schema_changes", "id,c1") - - expectedTableSchema := &protos.TableSchema{ - TableIdentifier: strings.ToUpper(dstTableName), - Columns: []*protos.FieldDescription{ - { - Name: "ID", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "C1", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "_PEERDB_IS_DELETED", - Type: string(qvalue.QValueKindBoolean), - TypeModifier: -1, - }, - { - Name: "_PEERDB_SYNCED_AT", - Type: string(qvalue.QValueKindTimestamp), - TypeModifier: -1, - }, - }, - } - output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ - TableIdentifiers: []string{dstTableName}, - }) - e2e.EnvNoError(s.t, env, err) - e2e.EnvTrue(s.t, env, e2e.CompareTableSchemas(expectedTableSchema, output.TableNameSchemaMapping[dstTableName])) - - // alter source table, add column c2 and insert another row. - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - ALTER TABLE %s ADD COLUMN c2 BIGINT`, srcTableName)) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Altered source table, added column c2") - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1,c2) VALUES ($1,$2)`, srcTableName), 2, 2) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Inserted row with added c2 in the source table") - - // verify we got our two rows, if schema did not match up it will error. - e2e.EnvWaitForEqualTables(env, s, "normalize altered row", "test_simple_schema_changes", "id,c1,c2") - expectedTableSchema = &protos.TableSchema{ - TableIdentifier: strings.ToUpper(dstTableName), - Columns: []*protos.FieldDescription{ - { - Name: "ID", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "C1", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "_PEERDB_SYNCED_AT", - Type: string(qvalue.QValueKindTimestamp), - TypeModifier: -1, - }, - { - Name: "C2", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - }, - } - output, err = s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ - TableIdentifiers: []string{dstTableName}, - }) - e2e.EnvNoError(s.t, env, err) - e2e.EnvTrue(s.t, env, e2e.CompareTableSchemas(expectedTableSchema, output.TableNameSchemaMapping[dstTableName])) - e2e.EnvEqualTables(env, s, "test_simple_schema_changes", "id,c1,c2") - - // alter source table, add column c3, drop column c2 and insert another row. - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - ALTER TABLE %s DROP COLUMN c2, ADD COLUMN c3 BIGINT`, srcTableName)) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Altered source table, dropped column c2 and added column c3") - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1,c3) VALUES ($1,$2)`, srcTableName), 3, 3) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Inserted row with added c3 in the source table") - - // verify we got our two rows, if schema did not match up it will error. - e2e.EnvWaitForEqualTables(env, s, "normalize dropped c2 column", "test_simple_schema_changes", "id,c1,c3") - expectedTableSchema = &protos.TableSchema{ - TableIdentifier: strings.ToUpper(dstTableName), - Columns: []*protos.FieldDescription{ - { - Name: "ID", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "C1", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "_PEERDB_SYNCED_AT", - Type: string(qvalue.QValueKindTimestamp), - TypeModifier: -1, - }, - { - Name: "C2", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "C3", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - }, - } - output, err = s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ - TableIdentifiers: []string{dstTableName}, - }) - e2e.EnvNoError(s.t, env, err) - e2e.EnvTrue(s.t, env, e2e.CompareTableSchemas(expectedTableSchema, output.TableNameSchemaMapping[dstTableName])) - e2e.EnvEqualTables(env, s, "test_simple_schema_changes", "id,c1,c3") - - // alter source table, drop column c3 and insert another row. - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - ALTER TABLE %s DROP COLUMN c3`, srcTableName)) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Altered source table, dropped column c3") - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1) VALUES ($1)`, srcTableName), 4) - e2e.EnvNoError(s.t, env, err) - s.t.Log("Inserted row after dropping all columns in the source table") - - // verify we got our two rows, if schema did not match up it will error. - e2e.EnvWaitForEqualTables(env, s, "normalize dropped c3 column", "test_simple_schema_changes", "id,c1") - expectedTableSchema = &protos.TableSchema{ - TableIdentifier: strings.ToUpper(dstTableName), - Columns: []*protos.FieldDescription{ - { - Name: "ID", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "C1", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "_PEERDB_SYNCED_AT", - Type: string(qvalue.QValueKindTimestamp), - TypeModifier: -1, - }, - { - Name: "C2", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - { - Name: "C3", - Type: string(qvalue.QValueKindNumeric), - TypeModifier: -1, - }, - }, - } - output, err = s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ - TableIdentifiers: []string{dstTableName}, - }) - e2e.EnvNoError(s.t, env, err) - e2e.EnvTrue(s.t, env, e2e.CompareTableSchemas(expectedTableSchema, output.TableNameSchemaMapping[dstTableName])) - e2e.EnvEqualTables(env, s, "test_simple_schema_changes", "id,c1") - - env.Cancel() - - e2e.RequireEnvCanceled(s.t, env) -} - func (s PeerFlowE2ETestSuiteSF) Test_Composite_PKey_SF() { tc := e2e.NewTemporalClient(s.t) diff --git a/flow/e2e/snowflake/snowflake.go b/flow/e2e/snowflake/snowflake.go index 45132ef601..06c46d1046 100644 --- a/flow/e2e/snowflake/snowflake.go +++ b/flow/e2e/snowflake/snowflake.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" + "github.com/PeerDB-io/peer-flow/connectors" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" "github.com/PeerDB-io/peer-flow/e2e" @@ -35,6 +36,10 @@ func (s PeerFlowE2ETestSuiteSF) Connector() *connpostgres.PostgresConnector { return s.conn } +func (s PeerFlowE2ETestSuiteSF) DestinationConnector() connectors.Connector { + return s.connector +} + func (s PeerFlowE2ETestSuiteSF) Conn() *pgx.Conn { return s.Connector().Conn() } diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index c018f32df2..dcaef74291 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -22,6 +22,7 @@ import ( "go.temporal.io/sdk/converter" "go.temporal.io/sdk/temporal" + "github.com/PeerDB-io/peer-flow/connectors" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -55,6 +56,7 @@ type RowSource interface { type GenericSuite interface { RowSource Peer() *protos.Peer + DestinationConnector() connectors.Connector DestinationTable(table string) string } @@ -112,13 +114,17 @@ func RequireEqualTables(suite RowSource, table string, cols string) { } func EnvEqualTables(env WorkflowRun, suite RowSource, table string, cols string) { + EnvEqualTablesWithNames(env, suite, table, table, cols) +} + +func EnvEqualTablesWithNames(env WorkflowRun, suite RowSource, srcTable string, dstTable string, cols string) { t := suite.T() t.Helper() - pgRows, err := GetPgRows(suite.Connector(), suite.Suffix(), table, cols) + pgRows, err := GetPgRows(suite.Connector(), suite.Suffix(), srcTable, cols) EnvNoError(t, env, err) - rows, err := suite.GetRows(table, cols) + rows, err := suite.GetRows(dstTable, cols) EnvNoError(t, env, err) EnvEqualRecordBatches(t, env, pgRows, rows) @@ -519,6 +525,19 @@ func GetOwnersSelectorStringsSF() [2]string { return [2]string{strings.Join(pgFields, ","), strings.Join(sfFields, ",")} } +func ExpectedDestinationIdentifier(s GenericSuite, ident string) string { + switch s.DestinationConnector().(type) { + case *connsnowflake.SnowflakeConnector: + return strings.ToUpper(ident) + default: + return ident + } +} + +func ExpectedDestinationTableName(s GenericSuite, table string) string { + return ExpectedDestinationIdentifier(s, s.DestinationTable(table)) +} + type testWriter struct { *testing.T } diff --git a/flow/go.mod b/flow/go.mod index 7b30930f2f..1235ae52b4 100644 --- a/flow/go.mod +++ b/flow/go.mod @@ -15,6 +15,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.17.7 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.16.9 github.com/aws/aws-sdk-go-v2/service/s3 v1.51.4 + github.com/aws/aws-sdk-go-v2/service/ses v1.22.2 github.com/aws/aws-sdk-go-v2/service/sns v1.29.2 github.com/cockroachdb/pebble v1.1.0 github.com/google/uuid v1.6.0 diff --git a/flow/go.sum b/flow/go.sum index 2f899957fe..4a3616a854 100644 --- a/flow/go.sum +++ b/flow/go.sum @@ -92,6 +92,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.3 h1:4t+QEX7BsXz98W github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.3/go.mod h1:oFcjjUq5Hm09N9rpxTdeMeLeQcxS7mIkBkL8qUKng+A= github.com/aws/aws-sdk-go-v2/service/s3 v1.51.4 h1:lW5xUzOPGAMY7HPuNF4FdyBwRc3UJ/e8KsapbesVeNU= github.com/aws/aws-sdk-go-v2/service/s3 v1.51.4/go.mod h1:MGTaf3x/+z7ZGugCGvepnx2DS6+caCYYqKhzVoLNYPk= +github.com/aws/aws-sdk-go-v2/service/ses v1.22.2 h1:cW5JtW23Lio3KDJ4l3jqRiOcCPKxJg7ooRA1SpIiuMo= +github.com/aws/aws-sdk-go-v2/service/ses v1.22.2/go.mod h1:MLj/NROJoperecxBME2zMN/O8Zrm0wv+6ah1Uqwpa1E= github.com/aws/aws-sdk-go-v2/service/sns v1.29.2 h1:kHm1SYs/NkxZpKINc4zOXOLJHVMzKtU4d7FlAMtDm50= github.com/aws/aws-sdk-go-v2/service/sns v1.29.2/go.mod h1:ZIs7/BaYel9NODoYa8PW39o15SFAXDEb4DxOG2It15U= github.com/aws/aws-sdk-go-v2/service/sso v1.20.2 h1:XOPfar83RIRPEzfihnp+U6udOveKZJvPQ76SKWrLRHc= diff --git a/flow/peerdbenv/config.go b/flow/peerdbenv/config.go index 9a59bad5ef..4903bec091 100644 --- a/flow/peerdbenv/config.go +++ b/flow/peerdbenv/config.go @@ -95,3 +95,20 @@ func PeerDBEnableParallelSyncNormalize() bool { func PeerDBTelemetryAWSSNSTopicArn() string { return getEnvString("PEERDB_TELEMETRY_AWS_SNS_TOPIC_ARN", "") } + +func PeerDBAlertingEmailSenderSourceEmail() string { + return getEnvString("PEERDB_ALERTING_EMAIL_SENDER_SOURCE_EMAIL", "") +} + +func PeerDBAlertingEmailSenderConfigurationSet() string { + return getEnvString("PEERDB_ALERTING_EMAIL_SENDER_CONFIGURATION_SET", "") +} + +func PeerDBAlertingEmailSenderRegion() string { + return getEnvString("PEERDB_ALERTING_EMAIL_SENDER_REGION", "") +} + +// Comma-separated reply-to addresses +func PeerDBAlertingEmailSenderReplyToAddresses() string { + return getEnvString("PEERDB_ALERTING_EMAIL_SENDER_REPLY_TO_ADDRESSES", "") +} diff --git a/flow/shared/aws_common/config.go b/flow/shared/aws_common/config.go new file mode 100644 index 0000000000..6eced96f05 --- /dev/null +++ b/flow/shared/aws_common/config.go @@ -0,0 +1,21 @@ +package aws_common + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" +) + +func LoadSdkConfig(ctx context.Context, region *string) (*aws.Config, error) { + sdkConfig, err := config.LoadDefaultConfig(ctx, func(options *config.LoadOptions) error { + if region != nil { + options.Region = *region + } + return nil + }) + if err != nil { + return nil, err + } + return &sdkConfig, nil +} diff --git a/flow/shared/telemetry/sns_message_sender.go b/flow/shared/telemetry/sns_message_sender.go index 9d32dcf8f9..218b693b38 100644 --- a/flow/shared/telemetry/sns_message_sender.go +++ b/flow/shared/telemetry/sns_message_sender.go @@ -8,10 +8,11 @@ import ( "unicode" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/aws/aws-sdk-go-v2/service/sns/types" "go.temporal.io/sdk/activity" + + "github.com/PeerDB-io/peer-flow/shared/aws_common" ) type SNSMessageSender interface { @@ -109,15 +110,10 @@ func NewSNSMessageSender(client *sns.Client, config *SNSMessageSenderConfig) SNS } func newSnsClient(ctx context.Context, region *string) (*sns.Client, error) { - sdkConfig, err := config.LoadDefaultConfig(ctx, func(options *config.LoadOptions) error { - if region != nil { - options.Region = *region - } - return nil - }) + sdkConfig, err := aws_common.LoadSdkConfig(ctx, region) if err != nil { return nil, err } - snsClient := sns.NewFromConfig(sdkConfig) + snsClient := sns.NewFromConfig(*sdkConfig) return snsClient, nil } diff --git a/nexus/catalog/migrations/V21__alert_constraint_update.sql b/nexus/catalog/migrations/V21__alert_constraint_update.sql new file mode 100644 index 0000000000..18ddb7d9f8 --- /dev/null +++ b/nexus/catalog/migrations/V21__alert_constraint_update.sql @@ -0,0 +1,6 @@ +ALTER TABLE peerdb_stats.alerting_config +DROP CONSTRAINT alerting_config_service_type_check; + +ALTER TABLE peerdb_stats.alerting_config +ADD CONSTRAINT alerting_config_service_type_check +CHECK (service_type IN ('slack', 'email')); \ No newline at end of file diff --git a/nexus/catalog/migrations/V22__alert_column_add_config_id.sql b/nexus/catalog/migrations/V22__alert_column_add_config_id.sql new file mode 100644 index 0000000000..35c1147bf0 --- /dev/null +++ b/nexus/catalog/migrations/V22__alert_column_add_config_id.sql @@ -0,0 +1,2 @@ +ALTER TABLE peerdb_stats.alerts_v1 +ADD COLUMN alert_config_id BIGINT DEFAULT NULL; diff --git a/ui/app/alert-config/new.tsx b/ui/app/alert-config/new.tsx index 6a399e4c81..12dfbc9cf5 100644 --- a/ui/app/alert-config/new.tsx +++ b/ui/app/alert-config/new.tsx @@ -1,21 +1,27 @@ import { Button } from '@/lib/Button'; import { TextField } from '@/lib/TextField'; import Image from 'next/image'; -import { useState } from 'react'; +import { Dispatch, SetStateAction, useState } from 'react'; import ReactSelect from 'react-select'; import { PulseLoader } from 'react-spinners'; import { ToastContainer, toast } from 'react-toastify'; import 'react-toastify/dist/ReactToastify.css'; import SelectTheme from '../styles/select'; -import { alertConfigReqSchema, alertConfigType } from './validation'; +import { + alertConfigReqSchema, + alertConfigType, + emailConfigType, + serviceConfigType, + serviceTypeSchemaMap, + slackConfigType, +} from './validation'; + +export type ServiceType = 'slack' | 'email'; export interface AlertConfigProps { id?: bigint; - serviceType: string; - authToken: string; - channelIdString: string; - slotLagGBAlertThreshold: number; - openConnectionsAlertThreshold: number; + serviceType: ServiceType; + alertConfig: serviceConfigType; forEdit?: boolean; } @@ -25,45 +31,152 @@ const notifyErr = (errMsg: string) => { }); }; -function ConfigLabel() { +function ConfigLabel(data: { label: string; value: string }) { return (
Authorisation Token
+Channel IDs
+Email Addresses
+Alert Provider
Authorisation Token
-Channel IDs
-Slot Lag Alert Threshold (in GB)
Open Connections Alert Threshold