Skip to content

Commit

Permalink
chore: tests for the pgproxy protocol (#3440)
Browse files Browse the repository at this point in the history
  • Loading branch information
jvmakine authored Nov 21, 2024
1 parent 81327f2 commit 604998c
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 21 deletions.
52 changes: 31 additions & 21 deletions internal/pgproxy/pgproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,34 +49,37 @@ func (p *PgProxy) Start(ctx context.Context) error {
logger.Errorf(err, "failed to accept connection")
continue
}
go p.handleConnection(ctx, conn)
go HandleConnection(ctx, conn, p.connectionStringFn)
}
}

func (p *PgProxy) handleConnection(ctx context.Context, conn net.Conn) {
// HandleConnection proxies a single connection.
//
// This should be run as the first thing after accepting a connection.
// It will block until the connection is closed.
func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstructor) {
defer conn.Close()

logger := log.FromContext(ctx)
logger.Infof("new connection established: %s", conn.RemoteAddr())

backend, startup, err := p.connectBackend(ctx, conn)
backend, startup, err := connectBackend(ctx, conn)
if err != nil {
logger.Errorf(err, "failed to connect backend")
return
}
logger.Debugf("startup message: %+v", startup)
logger.Debugf("backend connected: %s", conn.RemoteAddr())

frontend, err := p.connectFrontend(ctx, startup)
dsn, err := connectionFn(ctx, startup.Parameters)
if err != nil {
logger.Errorf(err, "failed to connect frontend")
backend.Send(&pgproto3.ErrorResponse{
Severity: "FATAL",
Message: err.Error(),
})
if err := backend.Flush(); err != nil {
logger.Errorf(err, "failed to flush backend error response")
}
handleBackendError(ctx, backend, err)
return
}

frontend, err := connectFrontend(ctx, dsn)
if err != nil {
handleBackendError(ctx, backend, err)
return
}
logger.Debugf("frontend connected")
Expand All @@ -88,15 +91,27 @@ func (p *PgProxy) handleConnection(ctx context.Context, conn net.Conn) {
return
}

if err := p.proxy(ctx, backend, frontend); err != nil {
if err := proxy(ctx, backend, frontend); err != nil {
logger.Warnf("disconnecting %s due to: %s", conn.RemoteAddr(), err)
return
}
logger.Infof("terminating connection to %s", conn.RemoteAddr())
}

func handleBackendError(ctx context.Context, backend *pgproto3.Backend, err error) {
logger := log.FromContext(ctx)
logger.Errorf(err, "backend error")
backend.Send(&pgproto3.ErrorResponse{
Severity: "FATAL",
Message: err.Error(),
})
if err := backend.Flush(); err != nil {
logger.Errorf(err, "failed to flush backend error response")
}
}

// connectBackend establishes a connection according to https://www.postgresql.org/docs/current/protocol-flow.html
func (p *PgProxy) connectBackend(_ context.Context, conn net.Conn) (*pgproto3.Backend, *pgproto3.StartupMessage, error) {
func connectBackend(_ context.Context, conn net.Conn) (*pgproto3.Backend, *pgproto3.StartupMessage, error) {
backend := pgproto3.NewBackend(conn, conn)

for {
Expand Down Expand Up @@ -127,12 +142,7 @@ func (p *PgProxy) connectBackend(_ context.Context, conn net.Conn) (*pgproto3.Ba
}
}

func (p *PgProxy) connectFrontend(ctx context.Context, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, error) {
dsn, err := p.connectionStringFn(ctx, startup.Parameters)
if err != nil {
return nil, err
}

func connectFrontend(ctx context.Context, dsn string) (*pgproto3.Frontend, error) {
conn, err := pgconn.Connect(ctx, dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to backend: %w", err)
Expand All @@ -142,7 +152,7 @@ func (p *PgProxy) connectFrontend(ctx context.Context, startup *pgproto3.Startup
return frontend, nil
}

func (p *PgProxy) proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error {
func proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error {
logger := log.FromContext(ctx)
frontendMessages := make(chan pgproto3.BackendMessage)
backendMessages := make(chan pgproto3.FrontendMessage)
Expand Down
83 changes: 83 additions & 0 deletions internal/pgproxy/pgproxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package pgproxy_test

import (
"context"
"net"
"testing"

"github.com/TBD54566975/ftl/internal/dev"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/pgproxy"
"github.com/alecthomas/assert/v2"
"github.com/jackc/pgx/v5/pgproto3"
)

func TestPgProxy(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
client, proxy := net.Pipe()

dsn, err := dev.SetupPostgres(ctx, "postgres:15.8", 0, false)
assert.NoError(t, err)

frontend := pgproto3.NewFrontend(client, client)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
go pgproxy.HandleConnection(ctx, proxy, func(ctx context.Context, parameters map[string]string) (string, error) {
return dsn, nil
})

t.Run("denies SSL", func(t *testing.T) {
frontend.Send(&pgproto3.SSLRequest{})
assert.NoError(t, frontend.Flush())

assert.Equal(t, readOneByte(t, client), 'N')
})

t.Run("denies GSSEnc", func(t *testing.T) {
frontend.Send(&pgproto3.GSSEncRequest{})
assert.NoError(t, frontend.Flush())

assert.Equal(t, readOneByte(t, client), 'N')
})

t.Run("authenticates with startup message", func(t *testing.T) {
frontend.Send(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{
"user": "ftl",
}})
assert.NoError(t, frontend.Flush())

assertResponseType[*pgproto3.AuthenticationOk](t, frontend)
assertResponseType[*pgproto3.ReadyForQuery](t, frontend)
})

t.Run("proxies a query to the underlying DB", func(t *testing.T) {
frontend.Send(&pgproto3.Query{String: "SELECT 1"})
assert.NoError(t, frontend.Flush())

assertResponseType[*pgproto3.RowDescription](t, frontend)
assertResponseType[*pgproto3.DataRow](t, frontend)
assertResponseType[*pgproto3.CommandComplete](t, frontend)
assertResponseType[*pgproto3.ReadyForQuery](t, frontend)
})
}

func readOneByte(t *testing.T, client net.Conn) byte {
t.Helper()

response := make([]byte, 1)
n, err := client.Read(response)
assert.NoError(t, err)
assert.Equal(t, n, 1)
return response[0]
}

func assertResponseType[T any](t *testing.T, f *pgproto3.Frontend) {
t.Helper()

var zero T
resp, err := f.Receive()
assert.NoError(t, err)
_, ok := resp.(T)
assert.True(t, ok, "expected response type %T, got %T", zero, resp)
}

0 comments on commit 604998c

Please sign in to comment.