Skip to content

Commit

Permalink
Add DatabaseConnection#inTransaction
Browse files Browse the repository at this point in the history
  • Loading branch information
MineKing9534 committed Oct 18, 2024
1 parent 91731cc commit 46f2af2
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 9 deletions.
20 changes: 20 additions & 0 deletions core/src/main/kotlin/de/mineking/database/DatabaseConnection.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package de.mineking.database

import org.jdbi.v3.core.Handle
import org.jdbi.v3.core.Jdbi
import org.jdbi.v3.core.kotlin.inTransactionUnchecked
import org.jdbi.v3.core.kotlin.withHandleUnchecked
import java.lang.reflect.Proxy
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
Expand All @@ -9,6 +12,8 @@ import kotlin.reflect.KType
import kotlin.reflect.full.memberProperties
import kotlin.reflect.jvm.javaField

internal val CURRENT_TRANSACTION: ThreadLocal<Handle> = ThreadLocal()

inline fun <reified T: Annotation> KProperty<*>.getDatabaseAnnotation(): T? = this.javaField?.getAnnotation(T::class.java)
inline fun <reified T: Annotation> KProperty<*>.hasDatabaseAnnotation(): Boolean = getDatabaseAnnotation<T>() != null

Expand Down Expand Up @@ -94,4 +99,19 @@ abstract class DatabaseConnection(

@Suppress("UNCHECKED_CAST")
fun <T: Any> getCachedTable(name: String): Table<T> = (tables[name] ?: throw IllegalArgumentException("Table $name not found")) as Table<T>

fun <R> inTransaction(action: (Handle) -> R): R = driver.inTransactionUnchecked { handle ->
try {
CURRENT_TRANSACTION.set(handle)
action(handle)
} finally {
CURRENT_TRANSACTION.set(null)
}
}

fun <R> execute(action: (Handle) -> R): R {
val transaction = CURRENT_TRANSACTION.get()
return if (transaction != null) action(transaction)
else driver.withHandleUnchecked { action(it) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package de.mineking.database.vendors
import de.mineking.database.*
import de.mineking.database.vendors.PostgresConnection.Companion.logger
import org.jdbi.v3.core.kotlin.useHandleUnchecked
import org.jdbi.v3.core.kotlin.withHandleUnchecked
import org.jdbi.v3.core.result.ResultIterable
import org.jdbi.v3.core.statement.UnableToExecuteStatementException
import org.jdbi.v3.core.statement.Update
Expand Down Expand Up @@ -59,7 +58,7 @@ class PostgresTable<T: Any>(
${ where.format(structure) }
""".trim().replace("\\s+".toRegex(), " ")

return structure.manager.driver.withHandleUnchecked { it.createQuery(sql)
return structure.manager.execute { it.createQuery(sql)
.bindMap(where.values(structure))
.mapTo(Int::class.java)
.first()
Expand Down Expand Up @@ -110,7 +109,7 @@ class PostgresTable<T: Any>(
val sql = createSelect(columnList.joinToString { "\"${it.first}\".\"${it.second}\" as \"${it.first}.${it.second}\"" }, where, order, limit, offset)
return object : RowQueryResult<T> {
override val instance: () -> T = this@PostgresTable.instance
override fun <O> execute(handler: ((T) -> Boolean) -> O): O = structure.manager.driver.withHandleUnchecked { it.createQuery(sql)
override fun <O> execute(handler: ((T) -> Boolean) -> O): O = structure.manager.execute { it.createQuery(sql)
.bindMap(where.values(structure))
.execute { stmt, _ ->
val statement = stmt.get()
Expand Down Expand Up @@ -140,7 +139,7 @@ class PostgresTable<T: Any>(

val sql = createSelect((columnList.map { "\"${ it.first }\".\"${ it.second }\" as \"${ it.first }.${ it.second }\"" } + "(${ target.format(structure) }) as \"value\"").joinToString(), where, order, limit, offset)
return object : ValueQueryResult<C> {
override fun <O> execute(handler: (ResultIterable<C>) -> O): O = structure.manager.driver.withHandleUnchecked { handler(it.createQuery(sql)
override fun <O> execute(handler: (ResultIterable<C>) -> O): O = structure.manager.execute { handler(it.createQuery(sql)
.bindMap(target.values(structure, column?.column))
.bindMap(where.values(structure))
.map { set, _ -> mapper.read(column?.column?.getRootColumn(), type, ReadContext(it, structure, set, columnList.map { "${ it.first }.${ it.second }" } + "value", autofillPrefix = { it != "value" }, shouldRead = false), "value") }
Expand Down Expand Up @@ -190,7 +189,7 @@ class PostgresTable<T: Any>(
""".trim().replace("\\s+".toRegex(), " ")

return createResult {
structure.manager.driver.withHandleUnchecked { executeUpdate(it.createUpdate(sql).bindMap(identity.values(structure)), obj) }
structure.manager.execute { executeUpdate(it.createUpdate(sql).bindMap(identity.values(structure)), obj) }
if (obj is DataObject<*>) obj.afterRead()
obj
}
Expand All @@ -208,7 +207,7 @@ class PostgresTable<T: Any>(
${ where.format(structure) }
""".trim().replace("\\s+".toRegex(), " ")

return createResult { structure.manager.driver.withHandleUnchecked { it.createUpdate(sql)
return createResult { structure.manager.execute { it.createUpdate(sql)
.bindMap(value.values(structure, spec.column))
.bindMap(where.values(structure))
.execute()
Expand All @@ -234,7 +233,7 @@ class PostgresTable<T: Any>(
""".trim().replace("\\s+".toRegex(), " ")

return createResult {
structure.manager.driver.withHandleUnchecked { executeUpdate(it.createUpdate(sql), obj) }
structure.manager.execute { executeUpdate(it.createUpdate(sql), obj) }
if (obj is DataObject<*>) obj.afterRead()
obj
}
Expand Down Expand Up @@ -264,7 +263,7 @@ class PostgresTable<T: Any>(
""".trim().replace("\\s+".toRegex(), " ")

return createResult {
structure.manager.driver.withHandleUnchecked { executeUpdate(it.createUpdate(sql), obj) }
structure.manager.execute { executeUpdate(it.createUpdate(sql), obj) }
if (obj is DataObject<*>) obj.afterRead()
obj
}
Expand All @@ -275,7 +274,7 @@ class PostgresTable<T: Any>(
*/
override fun delete(where: Where): Int {
val sql = "delete from ${ structure.name } ${ where.format(structure) }"
return structure.manager.driver.withHandleUnchecked { it.createUpdate(sql)
return structure.manager.execute { it.createUpdate(sql)
.bindMap(where.values(structure))
.execute()
}
Expand Down
40 changes: 40 additions & 0 deletions postgres/src/test/kotlin/tests/postgres/general/Transaction.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package tests.postgres.general

import de.mineking.database.vendors.PostgresConnection
import org.junit.jupiter.api.Test
import setup.ConsoleSqlLogger
import setup.UserDao
import setup.recreate
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertTrue

class TransactionTest {
val connection = PostgresConnection("localhost:5432/test", user = "test", password = "test")
val table = connection.getTable(name = "basic_test") { UserDao() }

init {
table.recreate()

connection.driver.setSqlLogger(ConsoleSqlLogger)
}

@Test
fun default() {
connection.inTransaction {
table.insert(UserDao(name = "Tom", email = "[email protected]", age = 12))
}

assertEquals(1, table.selectRowCount())
}

@Test
fun rollback() {
connection.inTransaction {
table.insert(UserDao(name = "Tom", email = "[email protected]", age = 12))
it.rollback()
}

assertEquals(0, table.selectRowCount())
}
}

0 comments on commit 46f2af2

Please sign in to comment.