diff --git a/build/common_jvm_maven.bzl b/build/common_jvm_maven.bzl index f7ae9ed7b..e7694fca4 100644 --- a/build/common_jvm_maven.bzl +++ b/build/common_jvm_maven.bzl @@ -88,7 +88,7 @@ def common_jvm_maven_artifacts_dict(): "com.google.cloud.sql:cloud-sql-connector-r2dbc-postgres": "1.6.2", "org.postgresql:postgresql": "42.4.0", "org.postgresql:r2dbc-postgresql": "0.9.1.RELEASE", - "com.opentable.components:otj-pg-embedded": "1.0.1", + "org.testcontainers:postgresql": "1.18.3", # Liquibase. "org.yaml:snakeyaml": "1.30", diff --git a/imports/java/com/opentable/db/postgres/BUILD.bazel b/imports/java/com/opentable/db/postgres/BUILD.bazel deleted file mode 100644 index 5973c3783..000000000 --- a/imports/java/com/opentable/db/postgres/BUILD.bazel +++ /dev/null @@ -1,6 +0,0 @@ -package(default_visibility = ["//visibility:public"]) - -alias( - name = "pg_embedded", - actual = "@maven//:com_opentable_components_otj_pg_embedded", -) diff --git a/imports/java/org/testcontainers/containers/BUILD.bazel b/imports/java/org/testcontainers/containers/BUILD.bazel new file mode 100644 index 000000000..0816861ce --- /dev/null +++ b/imports/java/org/testcontainers/containers/BUILD.bazel @@ -0,0 +1,6 @@ +package(default_visibility = ["//visibility:public"]) + +alias( + name = "postgresql", + actual = "@maven//:org_testcontainers_postgresql", +) diff --git a/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/BUILD.bazel index 58f9ba510..562d2f87b 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/BUILD.bazel @@ -6,14 +6,14 @@ package( ) kt_jvm_library( - name = "embedded_postgres", - srcs = ["EmbeddedPostgresDatabaseProvider.kt"], + name = "database_provider", + srcs = ["PostgresDatabaseProvider.kt"], runtime_deps = [ "//imports/java/liquibase/ext:postgresql", ], deps = [ - "//imports/java/com/opentable/db/postgres:pg_embedded", - "//src/main/kotlin/org/wfanet/measurement/common", + "//imports/java/org/testcontainers/containers:postgresql", + "//imports/kotlin/kotlinx/coroutines/reactive", "//src/main/kotlin/org/wfanet/measurement/common/db/liquibase", "//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", ], diff --git a/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/EmbeddedPostgresDatabaseProvider.kt b/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/EmbeddedPostgresDatabaseProvider.kt deleted file mode 100644 index d5f75b73b..000000000 --- a/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/EmbeddedPostgresDatabaseProvider.kt +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.common.db.r2dbc.postgres.testing - -import com.opentable.db.postgres.embedded.DatabasePreparer -import com.opentable.db.postgres.embedded.PreparedDbProvider -import io.r2dbc.spi.ConnectionFactories -import io.r2dbc.spi.ConnectionFactoryOptions -import java.net.URI -import java.nio.file.Path -import java.util.logging.Level -import javax.sql.DataSource -import kotlinx.coroutines.reactive.awaitFirst -import liquibase.Contexts -import liquibase.Scope -import org.wfanet.measurement.common.db.liquibase.Liquibase -import org.wfanet.measurement.common.db.liquibase.setLogLevel -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.queryMap - -class EmbeddedPostgresDatabaseProvider(changelogPath: Path) { - private val dbProvider: PreparedDbProvider = - PreparedDbProvider.forPreparer(LiquibasePreparer(changelogPath)) - - fun createNewDatabase(): PostgresDatabaseClient { - val jdbcConnectionUri = URI(dbProvider.createDatabase().removePrefix(JDBC_SCHEME_PREFIX)) - val queryParams: Map = jdbcConnectionUri.queryMap - val connectionFactory = - ConnectionFactories.get( - ConnectionFactoryOptions.builder() - .option(ConnectionFactoryOptions.DRIVER, "postgresql") - .option(ConnectionFactoryOptions.HOST, jdbcConnectionUri.host) - .option(ConnectionFactoryOptions.PORT, jdbcConnectionUri.port) - .option(ConnectionFactoryOptions.USER, queryParams.getValue("user")) - .option(ConnectionFactoryOptions.PASSWORD, queryParams.getValue("password")) - .option(ConnectionFactoryOptions.DATABASE, jdbcConnectionUri.path.trimStart('/')) - .build() - ) - - return PostgresDatabaseClient { connectionFactory.create().awaitFirst() } - } - - companion object { - private const val JDBC_SCHEME_PREFIX = "jdbc:" - } - - private class LiquibasePreparer(private val changelogPath: Path) : DatabasePreparer { - override fun prepare(ds: DataSource) { - ds.connection.use { connection -> - Liquibase.fromPath(connection, changelogPath).use { liquibase -> - Scope.getCurrentScope().setLogLevel(Level.FINE) - liquibase.update(Contexts()) - } - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/PostgresDatabaseProvider.kt b/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/PostgresDatabaseProvider.kt new file mode 100644 index 000000000..3d0cf829d --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing/PostgresDatabaseProvider.kt @@ -0,0 +1,107 @@ +/* + * Copyright 2023 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.db.r2dbc.postgres.testing + +import io.r2dbc.spi.ConnectionFactories +import io.r2dbc.spi.ConnectionFactoryOptions +import java.nio.file.Path +import java.util.concurrent.atomic.AtomicInteger +import java.util.logging.Level +import kotlinx.coroutines.reactive.awaitFirst +import liquibase.Contexts +import liquibase.Scope +import org.junit.rules.TestRule +import org.junit.runner.Description +import org.junit.runners.model.Statement +import org.testcontainers.containers.PostgreSQLContainer +import org.wfanet.measurement.common.db.liquibase.Liquibase +import org.wfanet.measurement.common.db.liquibase.setLogLevel +import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient + +interface PostgresDatabaseProvider { + /** Creates a new database within the instance. */ + fun createDatabase(): PostgresDatabaseClient +} + +/** + * [TestRule] which provides PostgreSQL databases. + * + * This is intended to be used as a [org.junit.ClassRule]. + */ +class PostgresDatabaseProviderRule(private val changelogPath: Path) : + PostgresDatabaseProvider, TestRule { + private val postgresContainer = KPostgresContainer(POSTGRES_IMAGE_NAME) + + override fun createDatabase() = postgresContainer.createDatabase() + + override fun apply(base: Statement, description: Description): Statement { + val dbProviderStatement = + object : Statement() { + override fun evaluate() { + postgresContainer.updateTemplateDatabase(changelogPath) + base.evaluate() + } + } + return (postgresContainer as TestRule).apply(dbProviderStatement, description) + } + + companion object { + /** Name of PostgreSQL Docker image. */ + private const val POSTGRES_IMAGE_NAME = "postgres:15" + private const val TEMPLATE_DATABASE_NAME = "template1" + + private val dbNumber = AtomicInteger() + + private fun KPostgresContainer.updateTemplateDatabase(changelogPath: Path) { + withDatabaseName(TEMPLATE_DATABASE_NAME).createConnection("").use { connection -> + Liquibase.fromPath(connection, changelogPath).use { liquibase -> + Scope.getCurrentScope().setLogLevel(Level.FINE) + liquibase.update(Contexts()) + } + } + } + + private fun KPostgresContainer.createDatabase(): PostgresDatabaseClient { + val dbNumber = dbNumber.incrementAndGet() + val databaseName = "database_$dbNumber" + createConnection("").use { connection -> + connection.createStatement().use { it.execute("CREATE DATABASE $databaseName") } + } + + val connectionFactory = + ConnectionFactories.get( + ConnectionFactoryOptions.builder() + .option(ConnectionFactoryOptions.DRIVER, "postgresql") + .option(ConnectionFactoryOptions.HOST, host) + .option( + ConnectionFactoryOptions.PORT, + getMappedPort(PostgreSQLContainer.POSTGRESQL_PORT) + ) + .option(ConnectionFactoryOptions.USER, username) + .option(ConnectionFactoryOptions.PASSWORD, password) + .option(ConnectionFactoryOptions.DATABASE, databaseName) + .build() + ) + + return PostgresDatabaseClient { connectionFactory.create().awaitFirst() } + } + } +} + +/** Kotlin generic type for [PostgreSQLContainer]. */ +private class KPostgresContainer(imageName: String) : + PostgreSQLContainer(imageName) diff --git a/src/test/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/BUILD.bazel index 2f04c5127..5b2b8b21c 100644 --- a/src/test/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/BUILD.bazel @@ -15,7 +15,7 @@ kt_jvm_test( "//imports/java/org/junit", "//src/main/kotlin/org/wfanet/measurement/common", "//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", - "//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:embedded_postgres", + "//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider", "//src/main/proto/google/type:dayofweek_kt_jvm_proto", "//src/main/proto/google/type:latlng_kt_jvm_proto", ], diff --git a/src/test/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/PostgresDatabaseClientTest.kt b/src/test/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/PostgresDatabaseClientTest.kt index eca765961..fe7210649 100644 --- a/src/test/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/PostgresDatabaseClientTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/PostgresDatabaseClientTest.kt @@ -25,19 +25,20 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking +import org.junit.ClassRule import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.wfanet.measurement.common.db.r2dbc.ReadWriteContext import org.wfanet.measurement.common.db.r2dbc.ResultRow import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.testing.EmbeddedPostgresDatabaseProvider +import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule import org.wfanet.measurement.common.getJarResourcePath import org.wfanet.measurement.common.identity.InternalId @RunWith(JUnit4::class) class PostgresDatabaseClientTest { - private val dbClient = dbProvider.createNewDatabase() + private val dbClient = databaseProvider.createDatabase() @Test fun `executeStatement returns result with updated rows`() { @@ -215,7 +216,8 @@ class PostgresDatabaseClientTest { companion object { private val CHANGELOG_PATH: Path = this::class.java.classLoader.getJarResourcePath("db/postgres/changelog.yaml")!! - private val dbProvider = EmbeddedPostgresDatabaseProvider(CHANGELOG_PATH) + + @get:ClassRule @JvmStatic val databaseProvider = PostgresDatabaseProviderRule(CHANGELOG_PATH) } }