diff --git a/internal/datastore/postgres/postgres_test.go b/internal/datastore/postgres/postgres_test.go index a80f9f3072..52b4417542 100644 --- a/internal/datastore/postgres/postgres_test.go +++ b/internal/datastore/postgres/postgres_test.go @@ -15,6 +15,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/samber/lo" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/sdk/trace" @@ -43,10 +44,25 @@ func (pgd *pgDatastore) ExampleRetryableError() error { } } +type postgresConfig struct { + targetMigration string + migrationPhase string + pgVersion string + pgbouncer bool +} + // the global OTel tracer is used everywhere, so we synchronize tests over a global test tracer var ( otelMutex = sync.Mutex{} testTraceProvider *trace.TracerProvider + postgresConfigs = lo.FlatMap( + []string{pgversion.MinimumSupportedPostgresVersion, "14", "15", "16"}, + func(postgresVersion string, _ int) []postgresConfig { + return lo.Map([]bool{false, true}, func(enablePgbouncer bool, _ int) postgresConfig { + return postgresConfig{"head", "", postgresVersion, enablePgbouncer} + }) + }, + ) ) func init() { @@ -59,20 +75,14 @@ func init() { func TestPostgresDatastore(t *testing.T) { t.Parallel() - for _, config := range []struct { - targetMigration string - migrationPhase string - pgVersion string - }{ - {"head", "", pgversion.MinimumSupportedPostgresVersion}, - {"head", "", "14"}, - {"head", "", "15"}, - {"head", "", "16"}, - } { - config := config - t.Run(fmt.Sprintf("postgres-%s-%s-%s", config.pgVersion, config.targetMigration, config.migrationPhase), func(t *testing.T) { + for _, config := range postgresConfigs { + pgbouncerStr := "" + if config.pgbouncer { + pgbouncerStr = "pgbouncer-" + } + t.Run(fmt.Sprintf("%spostgres-%s-%s-%s", pgbouncerStr, config.pgVersion, config.targetMigration, config.migrationPhase), func(t *testing.T) { t.Parallel() - b := testdatastore.RunPostgresForTesting(t, "", config.targetMigration, config.pgVersion) + b := testdatastore.RunPostgresForTesting(t, "", config.targetMigration, config.pgVersion, config.pgbouncer) test.All(t, test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { @@ -191,12 +201,13 @@ func TestPostgresDatastore(t *testing.T) { func TestPostgresDatastoreWithoutCommitTimestamps(t *testing.T) { t.Parallel() - for _, pgVersion := range []string{pgversion.MinimumSupportedPostgresVersion, "14", "15", "16"} { - pgVersion := pgVersion + for _, config := range postgresConfigs { + pgVersion := config.pgVersion + enablePgbouncer := config.pgbouncer t.Run(fmt.Sprintf("postgres-%s", pgVersion), func(t *testing.T) { t.Parallel() - b := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", "head", false, pgVersion) + b := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", "head", false, pgVersion, enablePgbouncer) // NOTE: watch API requires the commit timestamps, so we skip those tests here. test.AllWithExceptions(t, test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { @@ -1130,7 +1141,7 @@ func OTelTracingTest(t *testing.T, ds datastore.Datastore) { func WatchNotEnabledTest(t *testing.T, _ testdatastore.RunningEngineForTest, pgVersion string) { require := require.New(t) - ds := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", migrate.Head, false, pgVersion).NewDatastore(t, func(engine, uri string) datastore.Datastore { + ds := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", migrate.Head, false, pgVersion, false).NewDatastore(t, func(engine, uri string) datastore.Datastore { ds, err := newPostgresDatastore(uri, RevisionQuantization(0), GCWindow(time.Millisecond*1), @@ -1154,7 +1165,7 @@ func WatchNotEnabledTest(t *testing.T, _ testdatastore.RunningEngineForTest, pgV func BenchmarkPostgresQuery(b *testing.B) { req := require.New(b) - ds := testdatastore.RunPostgresForTesting(b, "", migrate.Head, pgversion.MinimumSupportedPostgresVersion).NewDatastore(b, func(engine, uri string) datastore.Datastore { + ds := testdatastore.RunPostgresForTesting(b, "", migrate.Head, pgversion.MinimumSupportedPostgresVersion, false).NewDatastore(b, func(engine, uri string) datastore.Datastore { ds, err := newPostgresDatastore(uri, RevisionQuantization(0), GCWindow(time.Millisecond*1), @@ -1188,7 +1199,7 @@ func BenchmarkPostgresQuery(b *testing.B) { func datastoreWithInterceptorAndTestData(t *testing.T, interceptor pgcommon.QueryInterceptor, pgVersion string) datastore.Datastore { require := require.New(t) - ds := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", migrate.Head, false, pgVersion).NewDatastore(t, func(engine, uri string) datastore.Datastore { + ds := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", migrate.Head, false, pgVersion, false).NewDatastore(t, func(engine, uri string) datastore.Datastore { ds, err := newPostgresDatastore(uri, RevisionQuantization(0), GCWindow(time.Millisecond*1), diff --git a/internal/testserver/datastore/datastore.go b/internal/testserver/datastore/datastore.go index 856e611179..3f9a65cba9 100644 --- a/internal/testserver/datastore/datastore.go +++ b/internal/testserver/datastore/datastore.go @@ -63,7 +63,7 @@ func RunDatastoreEngineWithBridge(t testing.TB, engine string, bridgeNetworkName case "cockroachdb": return RunCRDBForTesting(t, bridgeNetworkName) case "postgres": - return RunPostgresForTesting(t, bridgeNetworkName, migrate.Head, version.MinimumSupportedPostgresVersion) + return RunPostgresForTesting(t, bridgeNetworkName, migrate.Head, version.MinimumSupportedPostgresVersion, false) case "mysql": return RunMySQLForTesting(t, bridgeNetworkName) case "spanner": diff --git a/internal/testserver/datastore/postgres.go b/internal/testserver/datastore/postgres.go index 64001c70a8..2257413c75 100644 --- a/internal/testserver/datastore/postgres.go +++ b/internal/testserver/datastore/postgres.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" "github.com/stretchr/testify/require" pgmigrations "github.com/authzed/spicedb/internal/datastore/postgres/migrations" @@ -19,69 +20,91 @@ import ( "github.com/authzed/spicedb/pkg/secrets" ) +const ( + POSTGRES_TEST_USER = "postgres" + POSTGRES_TEST_PASSWORD = "secret" + POSTGRES_TEST_PORT = "5432" + POSTGRES_TEST_MAX_CONNECTIONS = "500" + PGBOUNCER_TEST_PORT = "6432" +) + +type container struct { + hostHostname string + hostPort string + containerHostname string + containerPort string +} + type postgresTester struct { - conn *pgx.Conn - hostname string - port string - creds string - targetMigration string + container + hostConn *pgx.Conn + creds string + targetMigration string + pgbouncerProxy *container + useContainerHostname bool } // RunPostgresForTesting returns a RunningEngineForTest for postgres -func RunPostgresForTesting(t testing.TB, bridgeNetworkName string, targetMigration string, pgVersion string) RunningEngineForTest { - return RunPostgresForTestingWithCommitTimestamps(t, bridgeNetworkName, targetMigration, true, pgVersion) +func RunPostgresForTesting(t testing.TB, bridgeNetworkName string, targetMigration string, pgVersion string, enablePgbouncer bool) RunningEngineForTest { + return RunPostgresForTestingWithCommitTimestamps(t, bridgeNetworkName, targetMigration, true, pgVersion, enablePgbouncer) } -func RunPostgresForTestingWithCommitTimestamps(t testing.TB, bridgeNetworkName string, targetMigration string, withCommitTimestamps bool, pgVersion string) RunningEngineForTest { +func RunPostgresForTestingWithCommitTimestamps(t testing.TB, bridgeNetworkName string, targetMigration string, withCommitTimestamps bool, pgVersion string, enablePgbouncer bool) RunningEngineForTest { pool, err := dockertest.NewPool("") require.NoError(t, err) - name := fmt.Sprintf("postgres-%s", uuid.New().String()) + bridgeSupplied := bridgeNetworkName != "" + if enablePgbouncer && !bridgeSupplied { + // We will need a network bridge if we're running pgbouncer + bridgeNetworkName = createNetworkBridge(t, pool) + } + + postgresContainerHostname := fmt.Sprintf("postgres-%s", uuid.New().String()) - cmd := []string{"-N", "500"} // Max Connections + cmd := []string{"-N", POSTGRES_TEST_MAX_CONNECTIONS} if withCommitTimestamps { cmd = append(cmd, "-c", "track_commit_timestamp=1") } - resource, err := pool.RunWithOptions(&dockertest.RunOptions{ - Name: name, - Repository: "postgres", - Tag: pgVersion, - Env: []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=defaultdb"}, - ExposedPorts: []string{"5432/tcp"}, + postgres, err := pool.RunWithOptions(&dockertest.RunOptions{ + Name: postgresContainerHostname, + Repository: "postgres", + Tag: pgVersion, + Env: []string{ + "POSTGRES_USER=" + POSTGRES_TEST_USER, + "POSTGRES_PASSWORD=" + POSTGRES_TEST_PASSWORD, + // use md5 auth to align postgres and pgbouncer auth methods + "POSTGRES_HOST_AUTH_METHOD=md5", + "POSTGRES_INITDB_ARGS=--auth=md5", + }, + ExposedPorts: []string{POSTGRES_TEST_PORT + "/tcp"}, NetworkID: bridgeNetworkName, Cmd: cmd, }) require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, pool.Purge(postgres)) + }) builder := &postgresTester{ - hostname: "localhost", - creds: "postgres:secret", - targetMigration: targetMigration, + container: container{ + hostHostname: "localhost", + hostPort: postgres.GetPort(POSTGRES_TEST_PORT + "/tcp"), + containerHostname: postgresContainerHostname, + containerPort: POSTGRES_TEST_PORT, + }, + creds: POSTGRES_TEST_USER + ":" + POSTGRES_TEST_PASSWORD, + targetMigration: targetMigration, + useContainerHostname: bridgeSupplied, } - t.Cleanup(func() { - require.NoError(t, pool.Purge(resource)) - }) - port := resource.GetPort(fmt.Sprintf("%d/tcp", 5432)) - if bridgeNetworkName != "" { - builder.hostname = name - builder.port = "5432" - } else { - builder.port = port + if enablePgbouncer { + // if we are running with pgbouncer enabled then set it up + builder.runPgbouncerForTesting(t, pool, bridgeNetworkName) } - uri := fmt.Sprintf("postgres://%s@localhost:%s/defaultdb?sslmode=disable", builder.creds, port) - require.NoError(t, pool.Retry(func() error { - var err error - ctx, cancelConnect := context.WithTimeout(context.Background(), dockerBootTimeout) - defer cancelConnect() - builder.conn, err = pgx.Connect(ctx, uri) - if err != nil { - return err - } - return nil - })) + builder.hostConn = builder.initializeHostConnection(t, pool) + return builder } @@ -91,14 +114,15 @@ func (b *postgresTester) NewDatabase(t testing.TB) string { newDBName := "db" + uniquePortion - _, err = b.conn.Exec(context.Background(), "CREATE DATABASE "+newDBName) + _, err = b.hostConn.Exec(context.Background(), "CREATE DATABASE "+newDBName) require.NoError(t, err) + hostname, port := b.getHostnameAndPort() return fmt.Sprintf( "postgres://%s@%s:%s/%s?sslmode=disable", b.creds, - b.hostname, - b.port, + hostname, + port, newDBName, ) } @@ -113,3 +137,94 @@ func (b *postgresTester) NewDatastore(t testing.TB, initFunc InitFunc) datastore return initFunc("postgres", connectStr) } + +func createNetworkBridge(t testing.TB, pool *dockertest.Pool) string { + bridgeNetworkName := fmt.Sprintf("bridge-%s", uuid.New().String()) + network, err := pool.Client.CreateNetwork(docker.CreateNetworkOptions{Name: bridgeNetworkName}) + + require.NoError(t, err) + t.Cleanup(func() { + pool.Client.RemoveNetwork(network.ID) + }) + + return bridgeNetworkName +} + +func (b *postgresTester) runPgbouncerForTesting(t testing.TB, pool *dockertest.Pool, bridgeNetworkName string) { + uniqueID := uuid.New().String() + pgbouncerContainerHostname := fmt.Sprintf("pgbouncer-%s", uniqueID) + + pgbouncer, err := pool.RunWithOptions(&dockertest.RunOptions{ + Name: pgbouncerContainerHostname, + Repository: "edoburu/pgbouncer", + Tag: "latest", + Env: []string{ + "DB_USER=" + POSTGRES_TEST_USER, + "DB_PASSWORD=" + POSTGRES_TEST_PASSWORD, + "DB_HOST=" + b.containerHostname, + "DB_PORT=" + b.containerPort, + "LISTEN_PORT=" + PGBOUNCER_TEST_PORT, + "DB_NAME=*", // Needed to make pgbouncer okay with the randomly named databases generated by the test suite + "AUTH_TYPE=md5", // use the same auth type as postgres + "MAX_CLIENT_CONN=" + POSTGRES_TEST_MAX_CONNECTIONS, + // params needed for spicedb + "POOL_MODE=session", // https://github.com/authzed/spicedb/issues/1217 + "IGNORE_STARTUP_PARAMETERS=plan_cache_mode", // Tell pgbouncer to pass this param thru to postgres. + }, + ExposedPorts: []string{PGBOUNCER_TEST_PORT + "/tcp"}, + NetworkID: bridgeNetworkName, + }) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, pool.Purge(pgbouncer)) + }) + + b.pgbouncerProxy = &container{ + hostHostname: "localhost", + hostPort: pgbouncer.GetPort(PGBOUNCER_TEST_PORT + "/tcp"), + containerHostname: pgbouncerContainerHostname, + containerPort: PGBOUNCER_TEST_PORT, + } +} + +func (b *postgresTester) initializeHostConnection(t testing.TB, pool *dockertest.Pool) (conn *pgx.Conn) { + hostname, port := b.getHostHostnameAndPort() + uri := fmt.Sprintf("postgresql://%s@%s:%s/?sslmode=disable", b.creds, hostname, port) + err := pool.Retry(func() error { + var err error + ctx, cancelConnect := context.WithTimeout(context.Background(), dockerBootTimeout) + defer cancelConnect() + conn, err = pgx.Connect(ctx, uri) + if err != nil { + return err + } + return nil + }) + require.NoError(t, err) + return conn +} + +func (b *postgresTester) getHostnameAndPort() (string, string) { + // If a bridgeNetworkName is supplied then we will return the container + // hostname and port that is resolvable from within the container network. + // If bridgeNetworkName is not supplied then the hostname and port will be + // resolvable from the host. + if b.useContainerHostname { + return b.getContainerHostnameAndPort() + } + return b.getHostHostnameAndPort() +} + +func (b *postgresTester) getHostHostnameAndPort() (string, string) { + if b.pgbouncerProxy != nil { + return b.pgbouncerProxy.hostHostname, b.pgbouncerProxy.hostPort + } + return b.hostHostname, b.hostPort +} + +func (b *postgresTester) getContainerHostnameAndPort() (string, string) { + if b.pgbouncerProxy != nil { + return b.pgbouncerProxy.containerHostname, b.pgbouncerProxy.containerPort + } + return b.containerHostname, b.containerPort +}