Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the ability to connect to Postgres via an SSH tunnel #800

Merged
merged 10 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions flow/cmd/peer_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -146,36 +147,36 @@ func (h *FlowRequestHandler) GetColumns(

defer peerPool.Close()
rows, err := peerPool.Query(ctx, `
SELECT
SELECT
cols.column_name,
cols.data_type,
CASE
CASE
WHEN constraint_type = 'PRIMARY KEY' THEN true
ELSE false
END AS is_primary_key
FROM
FROM
information_schema.columns cols
LEFT JOIN
LEFT JOIN
(
SELECT
SELECT
kcu.column_name,
tc.constraint_type
FROM
FROM
information_schema.key_column_usage kcu
JOIN
JOIN
information_schema.table_constraints tc
ON
ON
kcu.constraint_name = tc.constraint_name
AND kcu.constraint_schema = tc.constraint_schema
AND kcu.constraint_name = tc.constraint_name
WHERE
WHERE
tc.constraint_type = 'PRIMARY KEY'
AND kcu.table_schema = $1
AND kcu.table_name = $2
) AS pk
ON
ON
cols.column_name = pk.column_name
WHERE
WHERE
cols.table_schema = $3
AND cols.table_name = $4;
`, req.SchemaName, req.TableName, req.SchemaName, req.TableName)
Expand Down Expand Up @@ -210,14 +211,17 @@ func (h *FlowRequestHandler) GetSlotInfo(

pgConnector, err := connpostgres.NewPostgresConnector(ctx, pgConfig)
if err != nil {
logrus.Errorf("Failed to create postgres connector: %v", err)
return &protos.PeerSlotResponse{SlotData: nil}, err
}
defer pgConnector.Close()

slotInfo, err := pgConnector.GetSlotInfo("")
if err != nil {
logrus.Errorf("Failed to get slot info: %v", err)
return &protos.PeerSlotResponse{SlotData: nil}, err
}

return &protos.PeerSlotResponse{
SlotData: slotInfo,
}, nil
Expand All @@ -227,16 +231,27 @@ func (h *FlowRequestHandler) GetStatInfo(
ctx context.Context,
req *protos.PostgresPeerActivityInfoRequest,
) (*protos.PeerStatResponse, error) {
peerPool, peerUser, err := h.getPoolForPGPeer(ctx, req.PeerName)
pgConfig, err := h.getPGPeerConfig(ctx, req.PeerName)
if err != nil {
return &protos.PeerStatResponse{StatData: nil}, err
}
defer peerPool.Close()

pgConnector, err := connpostgres.NewPostgresConnector(ctx, pgConfig)
if err != nil {
logrus.Errorf("Failed to create postgres connector: %v", err)
return &protos.PeerStatResponse{StatData: nil}, err
}
defer pgConnector.Close()

peerPool := pgConnector.GetPool()
peerUser := pgConfig.User

rows, err := peerPool.Query(ctx, "SELECT pid, wait_event, wait_event_type, query_start::text, query,"+
"EXTRACT(epoch FROM(now()-query_start)) AS dur"+
" FROM pg_stat_activity WHERE "+
"usename=$1 AND state != 'idle';", peerUser)
if err != nil {
logrus.Errorf("Failed to get stat info: %v", err)
return &protos.PeerStatResponse{StatData: nil}, err
}
defer rows.Close()
Expand All @@ -251,6 +266,7 @@ func (h *FlowRequestHandler) GetStatInfo(

err := rows.Scan(&pid, &waitEvent, &waitEventType, &queryStart, &query, &duration)
if err != nil {
logrus.Errorf("Failed to scan row: %v", err)
return &protos.PeerStatResponse{StatData: nil}, err
}

Expand Down Expand Up @@ -288,6 +304,7 @@ func (h *FlowRequestHandler) GetStatInfo(
Duration: float32(d),
})
}

return &protos.PeerStatResponse{
StatData: statInfoRows,
}, nil
Expand Down
17 changes: 11 additions & 6 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ type PostgresConnector struct {
connStr string
ctx context.Context
config *protos.PostgresConfig
pool *pgxpool.Pool
replPool *pgxpool.Pool
pool *SSHWrappedPostgresPool
replPool *SSHWrappedPostgresPool
tableSchemaMapping map[string]*protos.TableSchema
customTypesMapping map[uint32]string
metadataSchema string
Expand All @@ -51,12 +51,12 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
// set pool size to 3 to avoid connection pool exhaustion
connConfig.MaxConns = 3

pool, err := pgxpool.NewWithConfig(ctx, connConfig)
pool, err := NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}

customTypeMap, err := utils.GetCustomDataTypes(ctx, pool)
customTypeMap, err := utils.GetCustomDataTypes(ctx, pool.Pool)
if err != nil {
return nil, fmt.Errorf("failed to get custom type map: %w", err)
}
Expand All @@ -73,7 +73,7 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)

// TODO: replPool not initializing might be intentional, if we only want to use QRep mirrors
// and the user doesn't have the REPLICATION permission
replPool, err := pgxpool.NewWithConfig(ctx, replConnConfig)
replPool, err := NewSSHWrappedPostgresPool(ctx, replConnConfig, pgConfig.SshConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
Expand All @@ -94,6 +94,11 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
}, nil
}

// GetPool returns the connector's main connection pool (the pool field),
// as opposed to the separate replication pool kept in replPool.
//
// The returned value is the connector's own *SSHWrappedPostgresPool pointer,
// not a copy, so it shares its lifetime with the connector.
// NOTE(review): Close appears to release c.pool, so callers presumably should
// rely on PostgresConnector.Close rather than closing this pool themselves —
// confirm against the SSHWrappedPostgresPool implementation.
func (c *PostgresConnector) GetPool() *SSHWrappedPostgresPool {
return c.pool
}

// Close closes all connections.
func (c *PostgresConnector) Close() error {
if c.pool != nil {
Expand Down Expand Up @@ -230,7 +235,7 @@ func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) error {

cdc, err := NewPostgresCDCSource(&PostgresCDCConfig{
AppContext: c.ctx,
Connection: c.replPool,
Connection: c.replPool.Pool,
SrcTableIDNameMapping: req.SrcTableIDNameMapping,
Slot: slotName,
Publication: publicationName,
Expand Down
15 changes: 10 additions & 5 deletions flow/connectors/postgres/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ func (c *PostgresConnector) PullQRepRecords(
log.WithFields(log.Fields{
"partitionId": partition.PartitionId,
}).Infof("pulling full table partition for flow job %s", config.FlowJobName)
executor, err := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return nil, err
Expand Down Expand Up @@ -361,7 +362,8 @@ func (c *PostgresConnector) PullQRepRecords(
return nil, err
}

executor, err := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return nil, err
Expand All @@ -386,7 +388,8 @@ func (c *PostgresConnector) PullQRepRecordStream(
"flowName": config.FlowJobName,
"partitionId": partition.PartitionId,
}).Infof("pulling full table partition for flow job %s", config.FlowJobName)
executor, err := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return 0, err
Expand Down Expand Up @@ -434,7 +437,8 @@ func (c *PostgresConnector) PullQRepRecordStream(
return 0, err
}

executor, err := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return 0, err
Expand Down Expand Up @@ -558,7 +562,8 @@ func (c *PostgresConnector) PullXminRecordStream(
query += " WHERE age(xmin) > 0 AND age(xmin) <= age($1::xid)"
}

executor, err := NewQRepQueryExecutorSnapshot(c.pool, c.ctx, c.config.TransactionSnapshot,
executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return 0, currentSnapshotXmin, err
Expand Down
6 changes: 3 additions & 3 deletions flow/connectors/postgres/qrep_partition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ func TestGetQRepPartitions(t *testing.T) {
t.Fatalf("Failed to parse config: %v", err)
}

pool, err := pgxpool.NewWithConfig(context.Background(), config)
pool, err := NewSSHWrappedPostgresPool(context.Background(), config, nil)
if err != nil {
t.Fatalf("unable to connect to database: %v", err)
t.Fatalf("Failed to create pool: %v", err)
}

// Generate a random schema name
Expand Down Expand Up @@ -101,7 +101,7 @@ func TestGetQRepPartitions(t *testing.T) {
}

// from 2010 Jan 1 10:00 AM UTC to 2010 Jan 30 10:00 AM UTC
numRows := prepareTestData(t, pool, schemaName)
numRows := prepareTestData(t, pool.Pool, schemaName)

secondsInADay := uint32(24 * time.Hour / time.Second)
fmt.Printf("secondsInADay: %d\n", secondsInADay)
Expand Down
Loading
Loading