Skip to content

Commit

Permalink
add test to ensure that embedded frag param exprs are evaluated only …
Browse files Browse the repository at this point in the history
…once
  • Loading branch information
AugustNagro committed Dec 1, 2024
1 parent 5dd58e1 commit df3530f
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 43 deletions.
92 changes: 50 additions & 42 deletions magnum/src/main/scala/com/augustnagro/magnum/util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,45 @@ private def sqlImpl(sc: Expr[StringContext], args: Expr[Seq[Any]])(using
// val stringExprs: Seq[Expr[String]] = sc match
// case '{ StringContext(${ Varargs(strings) }: _*) } => strings

val argsExprs = allArgsExprs.filter {
case '{ $arg: SqlLiteral } => false
case _ => true
}

'{
val allArgs = ${ Expr.ofSeq(allArgsExprs) }
val sqlQueryReprs = ${
queryReprs(allArgsExprs, '{ allArgs }, '{ Vector.newBuilder })
val args: Seq[Any] = ${ Expr.ofSeq(allArgsExprs) }

val sqlQueryReprs: Vector[String] = ${
queryReprs(allArgsExprs, '{ args }, '{ Vector.newBuilder })
}
val queryExpr: String = $sc.s(sqlQueryReprs: _*)

val flattenedArgs: Vector[Any] = ${
flattenedArgsExpr(allArgsExprs, '{ args }, '{ Vector.newBuilder })
}
val queryExpr = $sc.s(sqlQueryReprs: _*)
val args = allArgs.filter:
case _: SqlLiteral => false
case _ => true
val flattenedArgs = args.map:
case frag: Frag => frag.params
case x => x

val writer: FragWriter = (ps: PreparedStatement, pos: Int) => {
${ sqlWriter('{ ps }, '{ pos }, '{ args }, argsExprs, '{ 0 }) }
${ sqlWriter('{ ps }, '{ pos }, '{ args }, allArgsExprs) }
}
Frag(queryExpr, flattenedArgs, writer)
}
end sqlImpl

private def flattenedArgsExpr(
argsExprs: Seq[Expr[Any]],
allArgs: Expr[Seq[Any]],
builder: Expr[m.Builder[Any, Vector[Any]]],
i: Int = 0
)(using Quotes): Expr[Vector[Any]] =
argsExprs match
case '{ $arg: SqlLiteral } +: tail =>
flattenedArgsExpr(tail, allArgs, builder, i + 1)
case '{ $arg: Frag } +: tail =>
val newBuilder = '{
$builder ++= $allArgs(${ Expr(i) }).asInstanceOf[Frag].params
}
flattenedArgsExpr(tail, allArgs, newBuilder, i + 1)
case '{ $arg: tp } +: tail =>
val newBuilder = '{ $builder += $allArgs(${ Expr(i) }) }
flattenedArgsExpr(tail, allArgs, newBuilder, i + 1)
case Seq() =>
'{ $builder.result() }

private def queryReprs(
argsExprs: Seq[Expr[Any]],
allArgs: Expr[Seq[Any]],
Expand Down Expand Up @@ -121,35 +135,29 @@ private def sqlWriter(
posExpr: Expr[Int],
args: Expr[Seq[Any]],
argsExprs: Seq[Expr[Any]],
iExpr: Expr[Int]
i: Int = 0
)(using Quotes): Expr[Int] =
import quotes.reflect.*
argsExprs match
case head +: tail =>
head match
case '{ $arg: Frag } =>
'{
val i = $iExpr
val frag = $args(i).asInstanceOf[Frag]
val pos = $posExpr
val newPos = frag.writer.write($psExpr, pos)
val newI = i + 1
${ sqlWriter(psExpr, '{ newPos }, args, tail, '{ newI }) }
}
case '{ $arg: tp } =>
val codecExpr = summonWriter[tp]
'{
val i = $iExpr
val argValue = $args(i).asInstanceOf[tp]
val pos = $posExpr
val codec = $codecExpr
codec.writeSingle(argValue, $psExpr, pos)
val newPos = pos + codec.cols.length
val newI = i + 1
${ sqlWriter(psExpr, '{ newPos }, args, tail, '{ newI }) }
}
case _ =>
report.errorAndAbort("Args must be explicit", head)
case '{ $arg: SqlLiteral } +: tail =>
sqlWriter(psExpr, posExpr, args, tail, i + 1)
case '{ $arg: Frag } +: tail =>
'{
val frag = $args(${ Expr(i) }).asInstanceOf[Frag]
val pos = $posExpr
val newPos = frag.writer.write($psExpr, pos)
${ sqlWriter(psExpr, '{ newPos }, args, tail, i + 1) }
}
case '{ $arg: tp } +: tail =>
val codecExpr = summonWriter[tp]
'{
val argValue = $args(${ Expr(i) }).asInstanceOf[tp]
val pos = $posExpr
val codec = $codecExpr
codec.writeSingle(argValue, $psExpr, pos)
val newPos = pos + codec.cols.length
${ sqlWriter(psExpr, '{ newPos }, args, tail, i + 1) }
}
case Seq() => posExpr
end match
end sqlWriter
Expand Down
16 changes: 15 additions & 1 deletion magnum/src/test/scala/shared/EmbeddedFragTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ package shared
import com.augustnagro.magnum.*
import munit.FunSuite

import java.util.UUID

def embeddedFragTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)(
using munit.Location
): Unit =
import suite.*

test("embed Frag into Frag"):
def findPersonCnt(filter: Frag)(using DbCon): Int =
val x = sql"first_name IS NOT NULL"
val x = sql"id != ${util.Random.nextInt(20) + 20}"
sql"SELECT count(*) FROM person WHERE $filter AND $x"
.query[Int]
.run()
Expand All @@ -22,3 +24,15 @@ def embeddedFragTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)(
val johnCnt =
findPersonCnt(sql"$isAdminFrag AND first_name = 'John'")
assert(johnCnt == 2)

test("embedded frag param exprs should be evaluated only once"):
object Holder:
var uuid: UUID = _
def set(uuid: UUID): UUID =
this.uuid = uuid
uuid
val frag =
sql"select * from person where ${sql"social_id = ${Holder.set(UUID.randomUUID)}"}"
assert(frag.params.size == 1)
assert(frag.params.head == Holder.uuid)
end embeddedFragTests

0 comments on commit df3530f

Please sign in to comment.