Skip to content

Commit

Permalink
Use TestContainers instead of OpenTable embedded PostgreSQL.
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjayVas committed Aug 14, 2023
1 parent d6e938e commit 56831e6
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 86 deletions.
2 changes: 1 addition & 1 deletion build/common_jvm_maven.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def common_jvm_maven_artifacts_dict():
"com.google.cloud.sql:cloud-sql-connector-r2dbc-postgres": "1.6.2",
"org.postgresql:postgresql": "42.4.0",
"org.postgresql:r2dbc-postgresql": "0.9.1.RELEASE",
"com.opentable.components:otj-pg-embedded": "1.0.1",
"org.testcontainers:postgresql": "1.18.3",

# Liquibase.
"org.yaml:snakeyaml": "1.30",
Expand Down
6 changes: 0 additions & 6 deletions imports/java/com/opentable/db/postgres/BUILD.bazel

This file was deleted.

6 changes: 6 additions & 0 deletions imports/java/org/testcontainers/containers/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package(default_visibility = ["//visibility:public"])

alias(
name = "postgresql",
actual = "@maven//:org_testcontainers_postgresql",
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ package(
)

kt_jvm_library(
name = "embedded_postgres",
srcs = ["EmbeddedPostgresDatabaseProvider.kt"],
name = "database_provider",
srcs = ["PostgresDatabaseProvider.kt"],
runtime_deps = [
"//imports/java/liquibase/ext:postgresql",
],
deps = [
"//imports/java/com/opentable/db/postgres:pg_embedded",
"//src/main/kotlin/org/wfanet/measurement/common",
"//imports/java/org/testcontainers/containers:postgresql",
"//imports/kotlin/kotlinx/coroutines/reactive",
"//src/main/kotlin/org/wfanet/measurement/common/db/liquibase",
"//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres",
],
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright 2023 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.common.db.r2dbc.postgres.testing

import io.r2dbc.spi.ConnectionFactories
import io.r2dbc.spi.ConnectionFactoryOptions
import java.nio.file.Path
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Level
import kotlinx.coroutines.reactive.awaitFirst
import liquibase.Contexts
import liquibase.Scope
import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement
import org.testcontainers.containers.PostgreSQLContainer
import org.wfanet.measurement.common.db.liquibase.Liquibase
import org.wfanet.measurement.common.db.liquibase.setLogLevel
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient

/** Provider of PostgreSQL databases. */
interface PostgresDatabaseProvider {
/** Creates a new database within the instance. */
fun createDatabase(): PostgresDatabaseClient
}

/**
* [PostgresDatabaseProvider] implementation as a JUnit [TestRule].
*
* This is intended to be used as a [org.junit.ClassRule] so that the underlying container is only
* started once.
*/
class PostgresDatabaseProviderRule(private val changelogPath: Path) :
PostgresDatabaseProvider, TestRule {
private val postgresContainer = KPostgresContainer(POSTGRES_IMAGE_NAME)

override fun createDatabase() = postgresContainer.createDatabase()

override fun apply(base: Statement, description: Description): Statement {
val dbProviderStatement =
object : Statement() {
override fun evaluate() {
postgresContainer.updateTemplateDatabase(changelogPath)
base.evaluate()
}
}
return (postgresContainer as TestRule).apply(dbProviderStatement, description)
}

companion object {
/** Name of PostgreSQL Docker image. */
private const val POSTGRES_IMAGE_NAME = "postgres:15"
private const val TEMPLATE_DATABASE_NAME = "template1"

private val dbNumber = AtomicInteger()

private fun KPostgresContainer.updateTemplateDatabase(changelogPath: Path) {
withDatabaseName(TEMPLATE_DATABASE_NAME).createConnection("").use { connection ->
Liquibase.fromPath(connection, changelogPath).use { liquibase ->
Scope.getCurrentScope().setLogLevel(Level.FINE)
liquibase.update(Contexts())
}
}
}

private fun KPostgresContainer.createDatabase(): PostgresDatabaseClient {
val dbNumber = dbNumber.incrementAndGet()
val databaseName = "database_$dbNumber"
createConnection("").use { connection ->
connection.createStatement().use { it.execute("CREATE DATABASE $databaseName") }
}

val connectionFactory =
ConnectionFactories.get(
ConnectionFactoryOptions.builder()
.option(ConnectionFactoryOptions.DRIVER, "postgresql")
.option(ConnectionFactoryOptions.HOST, host)
.option(
ConnectionFactoryOptions.PORT,
getMappedPort(PostgreSQLContainer.POSTGRESQL_PORT)
)
.option(ConnectionFactoryOptions.USER, username)
.option(ConnectionFactoryOptions.PASSWORD, password)
.option(ConnectionFactoryOptions.DATABASE, databaseName)
.build()
)

return PostgresDatabaseClient { connectionFactory.create().awaitFirst() }
}
}
}

/** Kotlin generic type for [PostgreSQLContainer]. */
private class KPostgresContainer(imageName: String) :
PostgreSQLContainer<KPostgresContainer>(imageName)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ kt_jvm_test(
"//imports/java/org/junit",
"//src/main/kotlin/org/wfanet/measurement/common",
"//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres",
"//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:embedded_postgres",
"//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider",
"//src/main/proto/google/type:dayofweek_kt_jvm_proto",
"//src/main/proto/google/type:latlng_kt_jvm_proto",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,20 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.junit.ClassRule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.wfanet.measurement.common.db.r2dbc.ReadWriteContext
import org.wfanet.measurement.common.db.r2dbc.ResultRow
import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.testing.EmbeddedPostgresDatabaseProvider
import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule
import org.wfanet.measurement.common.getJarResourcePath
import org.wfanet.measurement.common.identity.InternalId

@RunWith(JUnit4::class)
class PostgresDatabaseClientTest {
private val dbClient = dbProvider.createNewDatabase()
private val dbClient = databaseProvider.createDatabase()

@Test
fun `executeStatement returns result with updated rows`() {
Expand Down Expand Up @@ -215,7 +216,8 @@ class PostgresDatabaseClientTest {
companion object {
private val CHANGELOG_PATH: Path =
this::class.java.classLoader.getJarResourcePath("db/postgres/changelog.yaml")!!
private val dbProvider = EmbeddedPostgresDatabaseProvider(CHANGELOG_PATH)

@get:ClassRule @JvmStatic val databaseProvider = PostgresDatabaseProviderRule(CHANGELOG_PATH)
}
}

Expand Down

0 comments on commit 56831e6

Please sign in to comment.