Skip to content

Commit

Permalink
Support custom discriminator in polymorphic forms
Browse files Browse the repository at this point in the history
  • Loading branch information
sake92 committed Oct 14, 2024
1 parent c9a1a68 commit 64b75e8
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 5 deletions.
17 changes: 13 additions & 4 deletions formson/src/ba/sake/formson/FormDataRW.scala
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ object FormDataRW {
private def derivedMacro[T: Type](using Quotes): Expr[FormDataRW[T]] = {
import quotes.reflect.*

def isAnnotation(a: quotes.reflect.Term): Boolean =
a.tpe.typeSymbol.maybeOwner.isNoSymbol ||
a.tpe.typeSymbol.owner.fullName != "scala.annotation.internal"

// only summon ProductOf ??
val mirror: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].getOrElse {
report.errorAndAbort(
Expand Down Expand Up @@ -351,28 +355,33 @@ object FormDataRW {
} else {
val rwInstancesExpr = summonInstances[T, elementTypes]
val rwInstances = Expr.ofList(rwInstancesExpr)
val annotations = Expr.ofList(TypeRepr.of[T].typeSymbol.annotations.filter(isAnnotation).map(_.asExpr))

'{
val discrOpt = $annotations.find(_.isInstanceOf[discriminator]).map(_.asInstanceOf[discriminator])
val discrName = discrOpt.map(_.name).getOrElse("@type")
new FormDataRW[T] {
override def write(path: String, value: T): FormData =
val index = $m.ordinal(value)
val typeName = $labels(index)
val rw = $rwInstances(index)
val res = rw.asInstanceOf[FormDataRW[Any]].write(path, value).asInstanceOf[FormData.Obj]
val newValues = res.values + ("@type" -> FormData.Simple(FormValue.Str(typeName)))
val newValues = res.values + (discrName -> FormData.Simple(FormValue.Str(typeName)))
res.copy(values = newValues)

override def parse(path: String, formData: FormData): T =

val tpeNameOpt = formData
.asInstanceOf[FormData.Obj]
.values
.get("@type")
.get(discrName)
.map {
case FormData.Simple(FormValue.Str(value)) => value
case FormData.Sequence(Seq(FormData.Simple(FormValue.Str(value)), _*)) => value
case other => throw ParsingException(
ParseError(
path,
s"@type has wrong type: '$other'."
s"${discrName} has wrong type: '$other'."
)
)
}
Expand All @@ -388,7 +397,7 @@ object FormDataRW {
)
val rw = $rwInstances(idx)
rw.parse(path, formData).asInstanceOf[T]
case None => throw ParsingException(ParseError(path, "@type not present"))
case None => throw ParsingException(ParseError(path, s"${discrName} not present"))
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions formson/src/ba/sake/formson/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,5 @@ case class ParseError(
case None => s"Key '$path' $msg"
}
}

case class discriminator(name: String) extends scala.annotation.StaticAnnotation
8 changes: 8 additions & 0 deletions formson/test/src/ba/sake/formson/FormDataParseSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ class FormDataParseSuite extends munit.FunSuite {
Sealed1.Case1("bla", 42)
)
)
// custom discriminator
assertEquals(
SeqMap(
"tip" -> Seq(FormValue.Str("B")),
"x" -> Seq(FormValue.Str("bla"))
).parseFormDataMap[Annot1],
Annot1.B("bla")
)
}

test("parseFormDataMap should throw nice errors") {
Expand Down
9 changes: 9 additions & 0 deletions formson/test/src/ba/sake/formson/FormDataWriteSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ class FormDataWriteSuite extends munit.FunSuite {
"nest.@type" -> Seq(FormValue.Str("Case1"))
)
)
// custom discriminator
val annot = Annot1.B("bla")
assertEquals(
annot.toFormDataMap(cfgObjDots),
SeqMap(
"tip" -> Seq(FormValue.Str("B")),
"x" -> Seq(FormValue.Str("bla"))
)
)
}

}
2 changes: 1 addition & 1 deletion formson/test/src/ba/sake/formson/types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ object Sealed1 {

case class NestedSealed1(nest: Sealed1) derives FormDataRW

// @ TODO
@discriminator("tip")
enum Annot1 derives FormDataRW:
case A
case B(x: String)
Expand Down

0 comments on commit 64b75e8

Please sign in to comment.