From c16a5d8dce6a1117e1e2a2639085814df401ce69 Mon Sep 17 00:00:00 2001 From: Zsolt Takacs Date: Thu, 5 Aug 2021 15:36:49 +0200 Subject: [PATCH] use type params in cache key so multiple instances can be correctly derived for generic classes --- src/main/scala/io/findify/flinkadt/api/package.scala | 5 +++-- .../scala/io/findify/flinkadt/SerializerTest.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/main/scala/io/findify/flinkadt/api/package.scala b/src/main/scala/io/findify/flinkadt/api/package.scala index 379708c..3c39fb7 100644 --- a/src/main/scala/io/findify/flinkadt/api/package.scala +++ b/src/main/scala/io/findify/flinkadt/api/package.scala @@ -42,7 +42,8 @@ package object api extends LowPrioImplicits { def combine[T <: Product: ClassTag: TypeTag]( ctx: CaseClass[TypeInformation, T] ): TypeInformation[T] = { - cache.get(ctx.typeName.full) match { + val cacheKey = s"${ctx.typeName.full}_${ctx.typeName.typeArguments}" + cache.get(cacheKey) match { case Some(cached) => cached.asInstanceOf[TypeInformation[T]] case None => val clazz = classTag[T].runtimeClass.asInstanceOf[Class[T]] @@ -60,7 +61,7 @@ package object api extends LowPrioImplicits { params = ctx.parameters, ser = serializer ) - cache.put(ctx.typeName.full, ti) + cache.put(cacheKey, ti) ti } } diff --git a/src/test/scala/io/findify/flinkadt/SerializerTest.scala b/src/test/scala/io/findify/flinkadt/SerializerTest.scala index 66b5839..e5e0f8b 100644 --- a/src/test/scala/io/findify/flinkadt/SerializerTest.scala +++ b/src/test/scala/io/findify/flinkadt/SerializerTest.scala @@ -12,6 +12,7 @@ import io.findify.flinkadt.SerializerTest.{ Bar2, Foo, Foo2, + Generic, ListADT, Nested, Node, @@ -142,6 +143,13 @@ class SerializerTest extends AnyFlatSpec with Matchers with Inspectors { roundtrip(ser, ListADT(List(Foo("a")))) } + it should "derive multiple instances of generic class" in { + val ser = implicitly[TypeInformation[Generic[SimpleOption]]].createSerializer(null) + val ser2 = implicitly[TypeInformation[Generic[Simple]]].createSerializer(null) + all(ser, Generic(SimpleOption(None), Bar(0))) + all(ser2, Generic(Simple(0, "asd"), Bar(0))) + } + def roundtrip[T](ser: TypeSerializer[T], in: T) = { val out = new ByteArrayOutputStream() ser.serialize(in, new DataOutputViewStreamWrapper(out)) @@ -219,5 +227,7 @@ object SerializerTest { case class SimpleOption(a: Option[String]) + case class Generic[T](a: T, b: ADT) + case class ListADT(a: List[ADT]) }