diff --git a/internal/pgproxy/pgproxy.go b/internal/pgproxy/pgproxy.go index 960326656a..13008a79e8 100644 --- a/internal/pgproxy/pgproxy.go +++ b/internal/pgproxy/pgproxy.go @@ -49,17 +49,21 @@ func (p *PgProxy) Start(ctx context.Context) error { logger.Errorf(err, "failed to accept connection") continue } - go p.handleConnection(ctx, conn) + go HandleConnection(ctx, conn, p.connectionStringFn) } } -func (p *PgProxy) handleConnection(ctx context.Context, conn net.Conn) { +// HandleConnection proxies a single connection. +// +// This should be run as the first thing after accepting a connection. +// It will block until the connection is closed. +func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstructor) { defer conn.Close() logger := log.FromContext(ctx) logger.Infof("new connection established: %s", conn.RemoteAddr()) - backend, startup, err := p.connectBackend(ctx, conn) + backend, startup, err := connectBackend(ctx, conn) if err != nil { logger.Errorf(err, "failed to connect backend") return @@ -67,16 +71,15 @@ func (p *PgProxy) handleConnection(ctx context.Context, conn net.Conn) { logger.Debugf("startup message: %+v", startup) logger.Debugf("backend connected: %s", conn.RemoteAddr()) - frontend, err := p.connectFrontend(ctx, startup) + dsn, err := connectionFn(ctx, startup.Parameters) if err != nil { - logger.Errorf(err, "failed to connect frontend") - backend.Send(&pgproto3.ErrorResponse{ - Severity: "FATAL", - Message: err.Error(), - }) - if err := backend.Flush(); err != nil { - logger.Errorf(err, "failed to flush backend error response") - } + handleBackendError(ctx, backend, err) + return + } + + frontend, err := connectFrontend(ctx, dsn) + if err != nil { + handleBackendError(ctx, backend, err) return } logger.Debugf("frontend connected") @@ -88,15 +91,27 @@ func (p *PgProxy) handleConnection(ctx context.Context, conn net.Conn) { return } - if err := p.proxy(ctx, backend, frontend); err != nil { + if err := proxy(ctx, backend, frontend); err != nil { logger.Warnf("disconnecting %s due to: %s", conn.RemoteAddr(), err) return } logger.Infof("terminating connection to %s", conn.RemoteAddr()) } +func handleBackendError(ctx context.Context, backend *pgproto3.Backend, err error) { + logger := log.FromContext(ctx) + logger.Errorf(err, "backend error") + backend.Send(&pgproto3.ErrorResponse{ + Severity: "FATAL", + Message: err.Error(), + }) + if err := backend.Flush(); err != nil { + logger.Errorf(err, "failed to flush backend error response") + } +} + // connectBackend establishes a connection according to https://www.postgresql.org/docs/current/protocol-flow.html -func (p *PgProxy) connectBackend(_ context.Context, conn net.Conn) (*pgproto3.Backend, *pgproto3.StartupMessage, error) { +func connectBackend(_ context.Context, conn net.Conn) (*pgproto3.Backend, *pgproto3.StartupMessage, error) { backend := pgproto3.NewBackend(conn, conn) for { @@ -127,12 +142,7 @@ func (p *PgProxy) connectBackend(_ context.Context, conn net.Conn) (*pgproto3.Ba } } -func (p *PgProxy) connectFrontend(ctx context.Context, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, error) { - dsn, err := p.connectionStringFn(ctx, startup.Parameters) - if err != nil { - return nil, err - } - +func connectFrontend(ctx context.Context, dsn string) (*pgproto3.Frontend, error) { conn, err := pgconn.Connect(ctx, dsn) if err != nil { return nil, fmt.Errorf("failed to connect to backend: %w", err) @@ -142,7 +152,7 @@ func (p *PgProxy) connectFrontend(ctx context.Context, startup *pgproto3.Startup return frontend, nil } -func (p *PgProxy) proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error { +func proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error { logger := log.FromContext(ctx) frontendMessages := make(chan pgproto3.BackendMessage) backendMessages := make(chan pgproto3.FrontendMessage) diff --git a/internal/pgproxy/pgproxy_test.go b/internal/pgproxy/pgproxy_test.go new file mode 100644 index 0000000000..2ef6633e5a --- /dev/null +++ b/internal/pgproxy/pgproxy_test.go @@ -0,0 +1,83 @@ +package pgproxy_test + +import ( + "context" + "net" + "testing" + + "github.com/TBD54566975/ftl/internal/dev" + "github.com/TBD54566975/ftl/internal/log" + "github.com/TBD54566975/ftl/internal/pgproxy" + "github.com/alecthomas/assert/v2" + "github.com/jackc/pgx/v5/pgproto3" +) + +func TestPgProxy(t *testing.T) { + ctx := log.ContextWithNewDefaultLogger(context.Background()) + client, proxy := net.Pipe() + + dsn, err := dev.SetupPostgres(ctx, "postgres:15.8", 0, false) + assert.NoError(t, err) + + frontend := pgproto3.NewFrontend(client, client) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go pgproxy.HandleConnection(ctx, proxy, func(ctx context.Context, parameters map[string]string) (string, error) { + return dsn, nil + }) + + t.Run("denies SSL", func(t *testing.T) { + frontend.Send(&pgproto3.SSLRequest{}) + assert.NoError(t, frontend.Flush()) + + assert.Equal(t, readOneByte(t, client), 'N') + }) + + t.Run("denies GSSEnc", func(t *testing.T) { + frontend.Send(&pgproto3.GSSEncRequest{}) + assert.NoError(t, frontend.Flush()) + + assert.Equal(t, readOneByte(t, client), 'N') + }) + + t.Run("authenticates with startup message", func(t *testing.T) { + frontend.Send(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{ + "user": "ftl", + }}) + assert.NoError(t, frontend.Flush()) + + assertResponseType[*pgproto3.AuthenticationOk](t, frontend) + assertResponseType[*pgproto3.ReadyForQuery](t, frontend) + }) + + t.Run("proxies a query to the underlying DB", func(t *testing.T) { + frontend.Send(&pgproto3.Query{String: "SELECT 1"}) + assert.NoError(t, frontend.Flush()) + + assertResponseType[*pgproto3.RowDescription](t, frontend) + assertResponseType[*pgproto3.DataRow](t, frontend) + assertResponseType[*pgproto3.CommandComplete](t, frontend) + assertResponseType[*pgproto3.ReadyForQuery](t, frontend) + }) +} + +func readOneByte(t *testing.T, client net.Conn) byte { + t.Helper() + + response := make([]byte, 1) + n, err := client.Read(response) + assert.NoError(t, err) + assert.Equal(t, n, 1) + return response[0] +} + +func assertResponseType[T any](t *testing.T, f *pgproto3.Frontend) { + t.Helper() + + var zero T + resp, err := f.Receive() + assert.NoError(t, err) + _, ok := resp.(T) + assert.True(t, ok, "expected response type %T, got %T", zero, resp) +}