From 7c7b5945e853d4faba0d42dff1d024bf793de783 Mon Sep 17 00:00:00 2001 From: Kaushik Iska Date: Tue, 12 Dec 2023 09:16:37 -0500 Subject: [PATCH] simplify --- flow/connectors/postgres/ssh_wrapped_pool.go | 108 ++++++++++--------- 1 file changed, 58 insertions(+), 50 deletions(-) diff --git a/flow/connectors/postgres/ssh_wrapped_pool.go b/flow/connectors/postgres/ssh_wrapped_pool.go index 21722ac7d9..d9fbde5bfd 100644 --- a/flow/connectors/postgres/ssh_wrapped_pool.go +++ b/flow/connectors/postgres/ssh_wrapped_pool.go @@ -3,9 +3,9 @@ package connpostgres import ( "context" "fmt" - "io" "net" "sync" + "time" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" @@ -22,7 +22,6 @@ type SSHWrappedPostgresPool struct { sshServer string once sync.Once sshClient *ssh.Client - localPort uint16 ctx context.Context cancel context.CancelFunc } @@ -47,6 +46,7 @@ func NewSSHWrappedPostgresPool( ) if err != nil { logrus.Error("Failed to get SSH client config: ", err) + cancel() return nil, err } } @@ -70,84 +70,92 @@ func NewSSHWrappedPostgresPool( func (swpp *SSHWrappedPostgresPool) connect() error { var err error swpp.once.Do(func() { - err = swpp.setupSSH(swpp.ctx) + err = swpp.setupSSH() if err != nil { return } + swpp.Pool, err = pgxpool.NewWithConfig(swpp.ctx, swpp.poolConfig) + if err != nil { + logrus.Errorf("Failed to create pool: %v", err) + return + } + + logrus.Infof("Established pool to %s:%d", + swpp.poolConfig.ConnConfig.Host, swpp.poolConfig.ConnConfig.Port) + + err = retryWithBackoff(func() error { + err = swpp.Ping(swpp.ctx) + if err != nil { + logrus.Errorf("Failed to ping pool: %v", err) + return err + } + return nil + }, 5, 5*time.Second) + + if err != nil { + logrus.Errorf("Failed to create pool: %v", err) + } }) return err } -func (swpp *SSHWrappedPostgresPool) setupSSH(ctx context.Context) error { +func (swpp *SSHWrappedPostgresPool) setupSSH() error { if swpp.sshConfig == nil { logrus.Info("SSH config is nil, skipping SSH setup") return nil } logrus.Info("Setting up SSH connection to ", swpp.sshServer) - var err error - // Establish an SSH connection + var err error swpp.sshClient, err = ssh.Dial("tcp", swpp.sshServer, swpp.sshConfig) if err != nil { return err } - logrus.Info("SSH connection established") - // Automatically pick an available local port - localListener, err := net.Listen("tcp", "localhost:0") - if err != nil { - return err - } - addr := localListener.Addr() - swpp.localPort = uint16(addr.(*net.TCPAddr).Port) - - go func() { - defer localListener.Close() - for { - select { - case <-ctx.Done(): - return - default: - localConn, err := localListener.Accept() - if err != nil { - return - } - - remoteConn, err := swpp.sshClient.Dial("tcp", "localhost:5432") - if err != nil { - localConn.Close() - return - } - - go func() { - defer localConn.Close() - defer remoteConn.Close() - select { - case <-ctx.Done(): - return - default: - io.Copy(localConn, remoteConn) - io.Copy(remoteConn, localConn) - } - }() - } - } - }() - // Update the connection string to use the dynamically assigned local port - swpp.poolConfig.ConnConfig.Host = "localhost" - swpp.poolConfig.ConnConfig.Port = swpp.localPort + swpp.poolConfig.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := swpp.sshClient.Dial(network, addr) + if err != nil { + return nil, err + } + return &noDeadlineConn{Conn: conn}, nil + } return nil } func (swpp *SSHWrappedPostgresPool) Close() { swpp.cancel() + if swpp.Pool != nil { swpp.Pool.Close() } + if swpp.sshClient != nil { swpp.sshClient.Close() } } + +type retryFunc func() error + +func retryWithBackoff(fn retryFunc, maxRetries int, backoff time.Duration) (err error) { + for i := 0; i < maxRetries; i++ { + err = fn() + if err == nil { + return nil + } + if i < maxRetries-1 { + logrus.Infof("Attempt #%d failed, retrying in %s", i+1, backoff) + time.Sleep(backoff) + } + } + return err +} + +// see: https://github.com/jackc/pgx/issues/382#issuecomment-1496586216 +type noDeadlineConn struct{ net.Conn } + +func (c *noDeadlineConn) SetDeadline(t time.Time) error { return nil } +func (c *noDeadlineConn) SetReadDeadline(t time.Time) error { return nil } +func (c *noDeadlineConn) SetWriteDeadline(t time.Time) error { return nil }