Query layer: support ssh fields, sync interval, simplify code path (#1704)

- Adds parsing for `SshConfig` fields in the `CREATE PEER` command for Postgres:
```sql
CREATE PEER postgres_peer FROM POSTGRES WITH
(
    host = '<hostname>',
    port = 5432,
    user = '<user>',
    password = '<password>',
    database = '<dbname>',
    ssh_config = '{
        "host": "<ssh_host>",
        "port": 22,
        "user": "<ssh_user>",
        "password": "<ssh_password>",
        "private_key": "<ssh_private>"
    }'
);
```

- You can now also specify `sync_interval` in the `CREATE MIRROR` command; the query layer forwards it to the flow service as `idle_timeout_seconds`, as sketched below
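
A minimal sketch of the new option, assuming the usual `CREATE MIRROR` shape (peer, mirror, and table names are placeholders, and `60` is an arbitrary interval in seconds):

```sql
CREATE MIRROR my_cdc_mirror FROM source_pg_peer TO target_pg_peer
WITH TABLE MAPPING (public.events:public.events)
WITH (sync_interval = 60);
```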

- Removes the catalog `flows` table entry-creation and update operations from the query-layer side, since we are hitting a gRPC endpoint anyway that takes care of that. This makes the `create_catalog_entry` flag in the request to the create/validate mirror endpoints redundant, so it is removed from `route.proto`

- Fixes a bug in validate mirror where the `select * from <source table> limit 0` check (sketched below) was skipped when the publication name was empty; the per-table probe now always runs, and only the publication checks are gated on a publication name being set
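
The check in question is a per-table empty-result probe, roughly of this shape (the table name is a placeholder):

```sql
SELECT * FROM public.events LIMIT 0;
```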

Functionally tested via the query layer and the UI
Amogh-Bharadwaj authored May 9, 2024
1 parent a2e9ab8 commit e2be383
Showing 13 changed files with 72 additions and 130 deletions.
12 changes: 5 additions & 7 deletions flow/cmd/handler.go
@@ -154,15 +154,13 @@ func (h *FlowRequestHandler) CreateCDCFlow(
req.ConnectionConfigs.SyncedAtColName = strings.ToUpper(req.ConnectionConfigs.SyncedAtColName)
}

if req.CreateCatalogEntry {
err := h.createCdcJobEntry(ctx, req, workflowID)
if err != nil {
slog.Error("unable to create flow job entry", slog.Any("error", err))
return nil, fmt.Errorf("unable to create flow job entry: %w", err)
}
err := h.createCdcJobEntry(ctx, req, workflowID)
if err != nil {
slog.Error("unable to create flow job entry", slog.Any("error", err))
return nil, fmt.Errorf("unable to create flow job entry: %w", err)
}

err := h.updateFlowConfigInCatalog(ctx, cfg)
err = h.updateFlowConfigInCatalog(ctx, cfg)
if err != nil {
slog.Error("unable to update flow config in catalog", slog.Any("error", err))
return nil, fmt.Errorf("unable to update flow config in catalog: %w", err)
28 changes: 15 additions & 13 deletions flow/cmd/validate_mirror.go
@@ -17,7 +17,7 @@ import (
func (h *FlowRequestHandler) ValidateCDCMirror(
ctx context.Context, req *protos.CreateCDCFlowRequest,
) (*protos.ValidateCDCMirrorResponse, error) {
if req.CreateCatalogEntry && !req.ConnectionConfigs.Resync {
if !req.ConnectionConfigs.Resync {
mirrorExists, existCheckErr := h.CheckIfMirrorNameExists(ctx, req.ConnectionConfigs.FlowJobName)
if existCheckErr != nil {
slog.Error("/validatecdc failed to check if mirror name exists", slog.Any("error", existCheckErr))
@@ -46,7 +46,9 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
if sourcePeerConfig == nil {
slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
return nil, errors.New("source peer config is nil")
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, errors.New("source peer config is nil")
}

pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig)
@@ -103,17 +105,17 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}

pubName := req.ConnectionConfigs.PublicationName
if pubName != "" {
err = pgPeer.CheckSourceTables(ctx, sourceTables, pubName)
if err != nil {
displayErr := fmt.Errorf("provided source tables invalidated: %v", err)
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(displayErr),
)
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, displayErr
}

err = pgPeer.CheckSourceTables(ctx, sourceTables, pubName)
if err != nil {
displayErr := fmt.Errorf("provided source tables invalidated: %v", err)
slog.Error(displayErr.Error())
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(displayErr),
)
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, displayErr
}

return &protos.ValidateCDCMirrorResponse{
2 changes: 1 addition & 1 deletion flow/connectors/postgres/postgres.go
@@ -62,11 +62,11 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
// create a separate connection pool for non-replication queries as replication connections cannot
// be used for extended query protocol, i.e. prepared statements
connConfig, err := pgx.ParseConfig(connectionString)
replConfig := connConfig.Copy()
if err != nil {
return nil, fmt.Errorf("failed to parse connection string: %w", err)
}

replConfig := connConfig.Copy()
runtimeParams := connConfig.Config.RuntimeParams
runtimeParams["idle_in_transaction_session_timeout"] = "0"
runtimeParams["statement_timeout"] = "0"
33 changes: 18 additions & 15 deletions flow/connectors/postgres/validate.go
@@ -34,28 +34,31 @@ func (c *PostgresConnector) CheckSourceTables(ctx context.Context,
}

tableStr := strings.Join(tableArr, ",")
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)

if pubName != "" {
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}

// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}
}

return nil
27 changes: 24 additions & 3 deletions nexus/analyzer/src/lib.rs
@@ -11,7 +11,7 @@ use pt::{
peerdb_peers::{
peer::Config, BigqueryConfig, ClickhouseConfig, DbType, EventHubConfig, GcpServiceAccount,
KafkaConfig, MongoConfig, Peer, PostgresConfig, PubSubConfig, S3Config, SnowflakeConfig,
SqlServerConfig,
SqlServerConfig, SshConfig,
},
};
use qrep::process_options;
@@ -300,6 +300,11 @@ impl StatementAnalyzer for PeerDDLAnalyzer {
_ => None,
};

let sync_interval: Option<u64> = match raw_options.remove("sync_interval") {
Some(Expr::Value(ast::Value::Number(n, _))) => Some(n.parse::<u64>()?),
_ => None,
};

let soft_delete_col_name: Option<String> = match raw_options
.remove("soft_delete_col_name")
{
@@ -347,6 +352,7 @@ impl StatementAnalyzer for PeerDDLAnalyzer {
push_batch_size,
push_parallelism,
max_batch_size,
sync_interval,
resync,
soft_delete_col_name,
synced_at_col_name,
@@ -646,6 +652,19 @@ fn parse_db_options(db_type: DbType, with_options: &[SqlOption]) -> anyhow::Resu
Config::MongoConfig(mongo_config)
}
DbType::Postgres => {
let ssh_fields: Option<SshConfig> = match opts.get("ssh_config") {
Some(ssh_config) => {
let ssh_config_str = ssh_config.to_string();
if ssh_config_str.is_empty() {
None
} else {
serde_json::from_str(&ssh_config_str)
.context("failed to deserialize ssh_config")?
}
}
None => None,
};

let postgres_config = PostgresConfig {
host: opts.get("host").context("no host specified")?.to_string(),
port: opts
@@ -667,8 +686,9 @@ fn parse_db_options(db_type: DbType, with_options: &[SqlOption]) -> anyhow::Resu
.to_string(),
metadata_schema: opts.get("metadata_schema").map(|s| s.to_string()),
transaction_snapshot: "".to_string(),
ssh_config: None,
ssh_config: ssh_fields,
};

Config::PostgresConfig(postgres_config)
}
DbType::S3 => {
@@ -744,7 +764,8 @@ fn parse_db_options(db_type: DbType, with_options: &[SqlOption]) -> anyhow::Resu
.unwrap_or_default(),
disable_tls: opts
.get("disable_tls")
.map(|s| s.parse::<bool>().unwrap_or_default()).unwrap_or_default(),
.map(|s| s.parse::<bool>().unwrap_or_default())
.unwrap_or_default(),
endpoint: opts.get("endpoint").map(|s| s.to_string()),
};
Config::ClickhouseConfig(clickhouse_config)
5 changes: 4 additions & 1 deletion nexus/analyzer/src/qrep.rs
@@ -196,7 +196,10 @@ pub fn process_options(

// If mode is upsert, we need unique key columns
if opts.get("mode") == Some(&Value::String(String::from("upsert")))
&& opts.get("unique_key_columns").map(|ukc| ukc == &Value::Array(Vec::new())).unwrap_or(true)
&& opts
.get("unique_key_columns")
.map(|ukc| ukc == &Value::Array(Vec::new()))
.unwrap_or(true)
{
anyhow::bail!("For upsert mode, unique_key_columns must be specified");
}
66 changes: 1 addition & 65 deletions nexus/catalog/src/lib.rs
@@ -6,7 +6,7 @@ use peer_postgres::{self, ast};
use pgwire::error::PgWireResult;
use postgres_connection::{connect_postgres, get_pg_connection_string};
use pt::{
flow_model::{FlowJob, QRepFlowJob},
flow_model::QRepFlowJob,
peerdb_peers::PostgresConfig,
peerdb_peers::{peer::Config, DbType, Peer},
prost::Message,
@@ -332,70 +332,6 @@ impl Catalog {
})
}

async fn normalize_schema_for_table_identifier(
&self,
table_identifier: &str,
peer_id: i32,
) -> anyhow::Result<String> {
let peer_dbtype = self.get_peer_type_for_id(peer_id).await?;

if !table_identifier.contains('.') && peer_dbtype != DbType::Bigquery {
Ok(format!("public.{}", table_identifier))
} else {
Ok(String::from(table_identifier))
}
}

pub async fn create_cdc_flow_job_entry(&self, job: &FlowJob) -> anyhow::Result<()> {
let source_peer_id = self
.get_peer_id_i32(&job.source_peer)
.await
.context("unable to get source peer id")?;
let destination_peer_id = self
.get_peer_id_i32(&job.target_peer)
.await
.context("unable to get destination peer id")?;

let stmt = self
.pg
.prepare_typed(
"INSERT INTO flows (name, source_peer, destination_peer, description,
source_table_identifier, destination_table_identifier) VALUES ($1, $2, $3, $4, $5, $6)",
&[types::Type::TEXT, types::Type::INT4, types::Type::INT4, types::Type::TEXT,
types::Type::TEXT, types::Type::TEXT],
)
.await?;

for table_mapping in &job.table_mappings {
let _rows = self
.pg
.execute(
&stmt,
&[
&job.name,
&source_peer_id,
&destination_peer_id,
&job.description,
&self
.normalize_schema_for_table_identifier(
&table_mapping.source_table_identifier,
source_peer_id,
)
.await?,
&self
.normalize_schema_for_table_identifier(
&table_mapping.destination_table_identifier,
destination_peer_id,
)
.await?,
],
)
.await?;
}

Ok(())
}

pub async fn get_qrep_flow_job_by_name(
&self,
job_name: &str,
3 changes: 1 addition & 2 deletions nexus/flow-rs/src/grpc.rs
@@ -79,7 +79,6 @@ impl FlowGrpcClient {
) -> anyhow::Result<String> {
let create_peer_flow_req = pt::peerdb_route::CreateCdcFlowRequest {
connection_configs: Some(peer_flow_config),
create_catalog_entry: false,
};
let response = self.client.create_cdc_flow(create_peer_flow_req).await?;
let workflow_id = response.into_inner().workflow_id;
@@ -176,7 +175,7 @@ impl FlowGrpcClient {
initial_snapshot_only: job.initial_snapshot_only,
script: job.script.clone(),
system: system as i32,
..Default::default()
idle_timeout_seconds: job.sync_interval.unwrap_or_default(),
};

self.start_peer_flow(flow_conn_cfg).await
1 change: 1 addition & 0 deletions nexus/pt/src/flow_model.rs
@@ -30,6 +30,7 @@ pub struct FlowJob {
pub push_parallelism: Option<i64>,
pub push_batch_size: Option<i64>,
pub max_batch_size: Option<u32>,
pub sync_interval: Option<u64>,
pub resync: bool,
pub soft_delete_col_name: Option<String>,
pub synced_at_col_name: Option<String>,
22 changes: 2 additions & 20 deletions nexus/server/src/main.rs
@@ -413,15 +413,6 @@ impl NexusBackend {
}
}

self.catalog
.create_cdc_flow_job_entry(flow_job)
.await
.map_err(|err| {
PgWireError::ApiError(
format!("unable to create mirror job entry: {:?}", err).into(),
)
})?;

// get source and destination peers
let (src_peer, dst_peer) = join!(
Self::get_peer_of_mirror(self.catalog.as_ref(), &flow_job.source_peer),
@@ -432,21 +423,12 @@

// make a request to the flow service to start the job.
let mut flow_handler = self.flow_handler.as_ref().unwrap().lock().await;
let workflow_id = flow_handler
flow_handler
.start_peer_flow_job(flow_job, src_peer, dst_peer)
.await
.map_err(|err| {
PgWireError::ApiError(
format!("unable to submit job: {:?}", err).into(),
)
})?;

self.catalog
.update_workflow_id_for_flow_job(&flow_job.name, &workflow_id)
.await
.map_err(|err| {
PgWireError::ApiError(
format!("unable to save job metadata: {:?}", err).into(),
format!("unable to submit job: {:?}", err.to_string()).into(),
)
})?;

1 change: 0 additions & 1 deletion protos/route.proto
@@ -10,7 +10,6 @@ package peerdb_route;

message CreateCDCFlowRequest {
peerdb_flow.FlowConnectionConfigs connection_configs = 1;
bool create_catalog_entry = 2;
}

message CreateCDCFlowResponse {
1 change: 0 additions & 1 deletion ui/app/api/mirrors/cdc/route.ts
@@ -12,7 +12,6 @@ export async function POST(request: Request) {
const flowServiceAddr = GetFlowHttpAddressFromEnv();
const req: CreateCDCFlowRequest = {
connectionConfigs: config,
createCatalogEntry: true,
};
try {
const createStatus: CreateCDCFlowResponse = await fetch(
1 change: 0 additions & 1 deletion ui/app/api/mirrors/cdc/validate/route.ts
@@ -11,7 +11,6 @@ export async function POST(request: NextRequest) {
const flowServiceAddr = GetFlowHttpAddressFromEnv();
const req: CreateCDCFlowRequest = {
connectionConfigs: config,
createCatalogEntry: false,
};
try {
const validateResponse: ValidateCDCMirrorResponse = await fetch(
