diff --git a/magnum/src/main/scala/com/augustnagro/magnum/util.scala b/magnum/src/main/scala/com/augustnagro/magnum/util.scala index d5e10c5..88ac68a 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/util.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/util.scala @@ -67,31 +67,45 @@ private def sqlImpl(sc: Expr[StringContext], args: Expr[Seq[Any]])(using // val stringExprs: Seq[Expr[String]] = sc match // case '{ StringContext(${ Varargs(strings) }: _*) } => strings - val argsExprs = allArgsExprs.filter { - case '{ $arg: SqlLiteral } => false - case _ => true - } - '{ - val allArgs = ${ Expr.ofSeq(allArgsExprs) } - val sqlQueryReprs = ${ - queryReprs(allArgsExprs, '{ allArgs }, '{ Vector.newBuilder }) + val args: Seq[Any] = ${ Expr.ofSeq(allArgsExprs) } + + val sqlQueryReprs: Vector[String] = ${ + queryReprs(allArgsExprs, '{ args }, '{ Vector.newBuilder }) + } + val queryExpr: String = $sc.s(sqlQueryReprs: _*) + + val flattenedArgs: Vector[Any] = ${ + flattenedArgsExpr(allArgsExprs, '{ args }, '{ Vector.newBuilder }) } - val queryExpr = $sc.s(sqlQueryReprs: _*) - val args = allArgs.filter: - case _: SqlLiteral => false - case _ => true - val flattenedArgs = args.map: - case frag: Frag => frag.params - case x => x val writer: FragWriter = (ps: PreparedStatement, pos: Int) => { - ${ sqlWriter('{ ps }, '{ pos }, '{ args }, argsExprs, '{ 0 }) } + ${ sqlWriter('{ ps }, '{ pos }, '{ args }, allArgsExprs) } } Frag(queryExpr, flattenedArgs, writer) } end sqlImpl +private def flattenedArgsExpr( + argsExprs: Seq[Expr[Any]], + allArgs: Expr[Seq[Any]], + builder: Expr[m.Builder[Any, Vector[Any]]], + i: Int = 0 +)(using Quotes): Expr[Vector[Any]] = + argsExprs match + case '{ $arg: SqlLiteral } +: tail => + flattenedArgsExpr(tail, allArgs, builder, i + 1) + case '{ $arg: Frag } +: tail => + val newBuilder = '{ + $builder ++= $allArgs(${ Expr(i) }).asInstanceOf[Frag].params + } + flattenedArgsExpr(tail, allArgs, newBuilder, i + 1) + case '{ $arg: tp } +: tail => + val newBuilder = '{ $builder += $allArgs(${ Expr(i) }) } + flattenedArgsExpr(tail, allArgs, newBuilder, i + 1) + case Seq() => + '{ $builder.result() } + private def queryReprs( argsExprs: Seq[Expr[Any]], allArgs: Expr[Seq[Any]], @@ -121,35 +135,29 @@ private def sqlWriter( posExpr: Expr[Int], args: Expr[Seq[Any]], argsExprs: Seq[Expr[Any]], - iExpr: Expr[Int] + i: Int = 0 )(using Quotes): Expr[Int] = import quotes.reflect.* argsExprs match - case head +: tail => - head match - case '{ $arg: Frag } => - '{ - val i = $iExpr - val frag = $args(i).asInstanceOf[Frag] - val pos = $posExpr - val newPos = frag.writer.write($psExpr, pos) - val newI = i + 1 - ${ sqlWriter(psExpr, '{ newPos }, args, tail, '{ newI }) } - } - case '{ $arg: tp } => - val codecExpr = summonWriter[tp] - '{ - val i = $iExpr - val argValue = $args(i).asInstanceOf[tp] - val pos = $posExpr - val codec = $codecExpr - codec.writeSingle(argValue, $psExpr, pos) - val newPos = pos + codec.cols.length - val newI = i + 1 - ${ sqlWriter(psExpr, '{ newPos }, args, tail, '{ newI }) } - } - case _ => - report.errorAndAbort("Args must be explicit", head) + case '{ $arg: SqlLiteral } +: tail => + sqlWriter(psExpr, posExpr, args, tail, i + 1) + case '{ $arg: Frag } +: tail => + '{ + val frag = $args(${ Expr(i) }).asInstanceOf[Frag] + val pos = $posExpr + val newPos = frag.writer.write($psExpr, pos) + ${ sqlWriter(psExpr, '{ newPos }, args, tail, i + 1) } + } + case '{ $arg: tp } +: tail => + val codecExpr = summonWriter[tp] + '{ + val argValue = $args(${ Expr(i) }).asInstanceOf[tp] + val pos = $posExpr + val codec = $codecExpr + codec.writeSingle(argValue, $psExpr, pos) + val newPos = pos + codec.cols.length + ${ sqlWriter(psExpr, '{ newPos }, args, tail, i + 1) } + } case Seq() => posExpr end match end sqlWriter diff --git a/magnum/src/test/scala/shared/EmbeddedFragTests.scala b/magnum/src/test/scala/shared/EmbeddedFragTests.scala index 6ed015e..fac9345 100644 --- a/magnum/src/test/scala/shared/EmbeddedFragTests.scala +++ b/magnum/src/test/scala/shared/EmbeddedFragTests.scala @@ -3,6 +3,8 @@ package shared import com.augustnagro.magnum.* import munit.FunSuite +import java.util.UUID + def embeddedFragTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)( using munit.Location ): Unit = @@ -10,7 +12,7 @@ def embeddedFragTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)( test("embed Frag into Frag"): def findPersonCnt(filter: Frag)(using DbCon): Int = - val x = sql"first_name IS NOT NULL" + val x = sql"id != ${util.Random.nextInt(20) + 20}" sql"SELECT count(*) FROM person WHERE $filter AND $x" .query[Int] .run() @@ -22,3 +24,15 @@ def embeddedFragTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)( val johnCnt = findPersonCnt(sql"$isAdminFrag AND first_name = 'John'") assert(johnCnt == 2) + + test("embedded frag param exprs should be evaluated only once"): + object Holder: + var uuid: UUID = _ + def set(uuid: UUID): UUID = + this.uuid = uuid + uuid + val frag = + sql"select * from person where ${sql"social_id = ${Holder.set(UUID.randomUUID)}"}" + assert(frag.params.size == 1) + assert(frag.params.head == Holder.uuid) +end embeddedFragTests