Skip to content

Commit

Permalink
add: Cats-Effect intergation for Magnum (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
KaranAhlawat committed Jan 15, 2025
1 parent 683f95a commit 8e4a080
Show file tree
Hide file tree
Showing 4 changed files with 381 additions and 0 deletions.
12 changes: 12 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,15 @@ lazy val magnumZio = project
"org.postgresql" % "postgresql" % postgresDriverVersion % Test
)
)

lazy val magnumCats = project
.in(file("magnum-cats-effect"))
.dependsOn(magnum % "compile->compile;test->test")
.settings(
Test / fork := true,
publish / skip := false,
libraryDependencies ++= Seq(
"org.typelevel" %% "cats-effect" % "3.5.7" % Provided,
"org.tpolecat" %% "natchez-core" % "0.3.7" % Provided
)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package com.augustnagro.magnum.magcats

import javax.sql.DataSource
import com.augustnagro.magnum.DbCon
import cats.effect.IO
import com.augustnagro.magnum.Transactor
import cats.effect.kernel.Resource
import java.sql.Connection
import com.augustnagro.magnum.DbTx
import cats.effect.unsafe.IORuntime
import scala.util.control.NonFatal
import natchez.Trace

/** Executes a given query on a given DataSource
*
* Re-implementation for IO of
* [[com.augustnagro.magnum.connect(dataSource: DataSource)]]
*
* Usage:
* {{{
* import com.augustnagro.magnum.magcats.*
*
* connect(datasource) { cn ?=> repo.findById(id) }
* }}}
*/
def connect[A](dataSource: DataSource)(q: DbCon ?=> A)(using
trace: Trace[IO]
): IO[A] =
connect(Transactor(dataSource))(q)

/** Executes a given query on a given Transactor
*
* Re-implementation for IO of
* [[com.augustnagro.magnum.connect(dataSource: DataSource)]]
*
* Usage:
* {{{
* import com.augustnagro.magnum.magcats.*
*
* connect(transactor) { cn ?=> repo.findById(id) }
* }}}
*/
def connect[A](
transactor: Transactor
)(q: DbCon ?=> A)(using trace: Trace[IO]): IO[A] =
Resource
.fromAutoCloseable(IO.interruptible(transactor.dataSource.getConnection()))
.use { cn =>
IO.interruptible {
transactor.connectionConfig(cn)
q(using DbCon(cn, transactor.sqlLogger))
}
}

/** Executes a given transaction on a given DataSource
*
* Re-implementation for IO of
* [[com.augustnagro.magnum.transact(transactor: Transactor)]]
*
* Usage:
* {{{
* import com.augustnagro.magnum.magcats.*
*
* transact(dataSource) { tx ?=> repo.insertReturning(creator) }
* }}}
*/
def transact[A](dataSource: DataSource)(q: DbTx ?=> A)(using
trace: Trace[IO]
): IO[A] =
transact(Transactor(dataSource))(q)

/** Executes a given transaction on a given DataSource
*
* Re-implementation for IO of
* [[com.augustnagro.magnum.transact(transactor: Transactor, connectionConfig: Connection => Unit)]]
*
* Usage:
* {{{
* import com.augustnagro.magnum.magcats.*
*
* transact(dataSource, ...) { tx ?=> repo.insertReturning(creator) }
* }}}
*/
def transact[A](dataSource: DataSource, connectionConfig: Connection => Unit)(
q: DbTx ?=> A
)(using trace: Trace[IO]): IO[A] =
val transactor =
Transactor(dataSource = dataSource, connectionConfig = connectionConfig)
transact(transactor)(q)

/** Executes a given transaction on a given Transactor
*
* Re-implementation for IO of
* [[com.augustnagro.magnum.transact(transactor: Transactor)]]
*
* Usage:
* {{{
* import com.augustnagro.magnum.magcats.*
*
* transact(transactor) { tx ?=> repo.insertReturning(creator) }
* }}}
*/
def transact[A](
transactor: Transactor
)(q: DbTx ?=> A)(using trace: Trace[IO]): IO[A] =
Resource
.fromAutoCloseable(IO.interruptible(transactor.dataSource.getConnection()))
.use { cn =>
IO.blocking {
transactor.connectionConfig(cn)
cn.setAutoCommit(false)
try {
val res = q(using DbTx(cn, transactor.sqlLogger))
cn.commit()
res
} catch {
case NonFatal(t) =>
cn.rollback()
throw t
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package com.augustnagro.magnum.magcats

import com.augustnagro.magnum.*
import shared.Color
import munit.FunSuite
import cats.effect.IO
import cats.effect.std.Dispatcher
import cats.effect.unsafe.IORuntime

import java.sql.Connection
import java.time.OffsetDateTime
import scala.util.{Success, Using}
import cats.effect.Trace

def immutableRepoCatsEffectTests(
suite: FunSuite,
dbType: DbType,
xa: () => Transactor
)(using
munit.Location,
DbCodec[OffsetDateTime]
): Unit =
import suite.*

given IORuntime = IORuntime.global
given natchez.Trace[IO] = natchez.Trace.Implicits.noop

def runIO[A](io: IO[A]): A =
io.unsafeRunSync()

@Table(dbType, SqlNameMapper.CamelToSnakeCase)
case class Car(
model: String,
@Id id: Long,
topSpeed: Int,
@SqlName("vin") vinNumber: Option[Int],
color: Color,
created: OffsetDateTime
) derives DbCodec

val carRepo = ImmutableRepo[Car, Long]
val car = TableInfo[Car, Car, Long]

val allCars = Vector(
Car(
model = "McLaren Senna",
id = 1L,
topSpeed = 208,
vinNumber = Some(123),
color = Color.Red,
created = OffsetDateTime.parse("2024-11-24T22:17:30.000000000Z")
),
Car(
model = "Ferrari F8 Tributo",
id = 2L,
topSpeed = 212,
vinNumber = Some(124),
color = Color.Green,
created = OffsetDateTime.parse("2024-11-24T22:17:31.000000000Z")
),
Car(
model = "Aston Martin Superleggera",
id = 3L,
topSpeed = 211,
vinNumber = None,
color = Color.Blue,
created = OffsetDateTime.parse("2024-11-24T22:17:32.000000000Z")
)
)

test("count"):
val count =
runIO:
magcats.connect(xa()):
carRepo.count
assert(count == 3L)

test("existsById"):
val (exists3, exists4) =
runIO:
magcats.connect(xa()):
carRepo.existsById(3L) -> carRepo.existsById(4L)
assert(exists3)
assert(!exists4)

test("findAll"):
val cars =
runIO:
magcats.connect(xa()):
carRepo.findAll
assert(cars == allCars)

test("findById"):
val (exists3, exists4) =
runIO:
magcats.connect(xa()):
carRepo.findById(3L) -> carRepo.findById(4L)
assert(exists3.get == allCars.last)
assert(exists4 == None)

test("findAllByIds"):
assume(dbType != ClickhouseDbType)
assume(dbType != MySqlDbType)
assume(dbType != OracleDbType)
assume(dbType != SqliteDbType)
val ids =
runIO:
magcats.connect(xa()):
carRepo.findAllById(Vector(1L, 3L)).map(_.id)
assert(ids == Vector(1L, 3L))

test("serializable transaction"):
val count =
runIO:
magcats.transact(xa().copy(connectionConfig = withSerializable)):
carRepo.count
assert(count == 3L)

def withSerializable(con: Connection): Unit =
con.setTransactionIsolation(Connection.TRANSACTION_SERIALIZABLE)

test("select query"):
val minSpeed: Int = 210
val query =
sql"select ${car.all} from $car where ${car.topSpeed} > $minSpeed"
.query[Car]
val result =
runIO:
magcats.connect(xa()):
query.run()
assertNoDiff(
query.frag.sqlString,
"select model, id, top_speed, vin, color, created from car where top_speed > ?"
)
assert(query.frag.params == Vector(minSpeed))
assert(result == allCars.tail)

test("select query with aliasing"):
val minSpeed = 210
val cAlias = car.alias("c")
val query =
sql"select ${cAlias.all} from $cAlias where ${cAlias.topSpeed} > $minSpeed"
.query[Car]
val result =
runIO:
magcats.connect(xa()):
query.run()
assertNoDiff(
query.frag.sqlString,
"select c.model, c.id, c.top_speed, c.vin, c.color, c.created from car c where c.top_speed > ?"
)
assert(query.frag.params == Vector(minSpeed))
assert(result == allCars.tail)

test("select via option"):
val vin = Some(124)
val cars =
runIO:
magcats.connect(xa()):
sql"select * from car where vin = $vin"
.query[Car]
.run()
assert(cars == allCars.filter(_.vinNumber == vin))

test("tuple select"):
val tuples =
runIO:
magcats.connect(xa()):
sql"select model, color from car where id = 2"
.query[(String, Color)]
.run()
assert(tuples == Vector(allCars(1).model -> allCars(1).color))

test("reads null int as None and not Some(0)"):
val maybeCar =
runIO:
magcats.connect(xa()):
carRepo.findById(3L)
assert(maybeCar.get.vinNumber == None)

test("created timestamps should match"):
val allCars =
runIO:
magcats.connect(xa()):
carRepo.findAll
assert(allCars.map(_.created) == allCars.map(_.created))

test(".query iterator"):
val carsCount =
runIO:
magcats.connect(xa()):
Using.Manager(implicit use =>
val it = sql"SELECT * FROM car".query[Car].iterator()
it.map(_.id).size
)
assert(carsCount == Success(3))

end immutableRepoCatsEffectTests
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.augustnagro.magnum.magzio

import com.augustnagro.magnum.*
import com.dimafeng.testcontainers.PostgreSQLContainer
import com.dimafeng.testcontainers.munit.fixtures.TestContainersFixtures
import munit.{AnyFixture, FunSuite, Location}
import org.postgresql.ds.PGSimpleDataSource
import org.testcontainers.utility.DockerImageName

import java.nio.file.{Files, Path}
import scala.util.Using
import scala.util.Using.Manager
import com.augustnagro.magnum.magcats.immutableRepoCatsEffectTests

class PgCatsEffectTests extends FunSuite, TestContainersFixtures:

immutableRepoCatsEffectTests(this, PostgresDbType, xa)

val pgContainer = ForAllContainerFixture(
PostgreSQLContainer
.Def(dockerImageName = DockerImageName.parse("postgres:17.0"))
.createContainer()
)

override def munitFixtures: Seq[AnyFixture[_]] =
super.munitFixtures :+ pgContainer

def xa(): Transactor =
val ds = PGSimpleDataSource()
val pg = pgContainer()
ds.setUrl(pg.jdbcUrl)
ds.setUser(pg.username)
ds.setPassword(pg.password)
val tableDDLs = Vector(
"/pg/car.sql",
"/pg/person.sql",
"/pg/my-user.sql",
"/pg/no-id.sql",
"/pg/big-dec.sql"
).map(p => Files.readString(Path.of(getClass.getResource(p).toURI)))

Manager(use =>
val con = use(ds.getConnection)
val stmt = use(con.createStatement)
for ddl <- tableDDLs do stmt.execute(ddl)
).get
Transactor(ds)
end xa
end PgCatsEffectTests

0 comments on commit 8e4a080

Please sign in to comment.