From 232e92375f76da2433b1ce534511cf0f120be455 Mon Sep 17 00:00:00 2001
From: Michel Davit
Date: Mon, 18 Dec 2023 12:17:33 +0100
Subject: [PATCH] Parquet map logical type

---
 .../magnolify/parquet/ParquetField.scala      | 59 +++++++++++++++++++
 .../magnolify/parquet/ParquetTypeSuite.scala  |  5 ++
 .../magnolify/tools/ParquetParserSuite.scala  | 47 +++++++++------
 3 files changed, 92 insertions(+), 19 deletions(-)

diff --git a/parquet/src/main/scala/magnolify/parquet/ParquetField.scala b/parquet/src/main/scala/magnolify/parquet/ParquetField.scala
index 383241ada..d939248c8 100644
--- a/parquet/src/main/scala/magnolify/parquet/ParquetField.scala
+++ b/parquet/src/main/scala/magnolify/parquet/ParquetField.scala
@@ -366,6 +366,65 @@ object ParquetField {
     }
   }
 
+  private val keyField = "key"
+  private val valueField = "value"
+  private val mapGroup = "key_value"
+  implicit def pfMap[T](implicit pf: ParquetField[T]): ParquetField[Map[String, T]] =
+    new ParquetField[Map[String, T]] {
+      override def buildSchema(cm: CaseMapper): Type =
+        Types
+          .repeatedGroup()
+          .addField(Schema.rename(pfString.buildSchema(cm), keyField))
+          .addField(Schema.rename(pf.schema(cm), valueField))
+          .as(LogicalTypeAnnotation.mapType())
+          .named(mapGroup)
+
+      override val hasAvroArray: Boolean = pf.hasAvroArray
+
+      override protected def isEmpty(v: Map[String, T]): Boolean = v.isEmpty
+
+      override def fieldDocs(cm: CaseMapper): Map[String, String] = pf.fieldDocs(cm)
+
+      override val typeDoc: Option[String] = None
+
+      override def write(c: RecordConsumer, v: Map[String, T])(cm: CaseMapper): Unit = {
+        v.foreach { case (k, v) =>
+          c.startGroup()
+          c.startField(keyField, 0)
+          pfString.writeGroup(c, k)(cm)
+          c.endField(keyField, 0)
+          c.startField(valueField, 1)
+          pf.writeGroup(c, v)(cm)
+          c.endField(valueField, 1)
+          c.endGroup()
+        }
+      }
+
+      override def newConverter: TypeConverter[Map[String, T]] = {
+        val kvConverter = new GroupConverter with TypeConverter.Buffered[(String, T)] {
+          private val keyConverter = pfString.newConverter
+          private val valueConverter = pf.newConverter
+          private val fieldConverters = Array(keyConverter, valueConverter)
+
+          override def isPrimitive: Boolean = false
+
+          override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex)
+
+          override def start(): Unit = ()
+
+          override def end(): Unit = {
+            val key = keyConverter.get
+            val value = valueConverter.get
+            addValue(key -> value)
+          }
+        }.withRepetition(Repetition.REPEATED)
+
+        new TypeConverter.Delegate[(String, T), Map[String, T]](kvConverter) {
+          override def get: Map[String, T] = inner.get(_.toMap)
+        }
+      }
+    }
+
   // ////////////////////////////////////////////////
 
   def logicalType[T](lta: => LogicalTypeAnnotation): LogicalTypeWord[T] =
diff --git a/parquet/src/test/scala/magnolify/parquet/ParquetTypeSuite.scala b/parquet/src/test/scala/magnolify/parquet/ParquetTypeSuite.scala
index 30cbf29e3..46a9734c4 100644
--- a/parquet/src/test/scala/magnolify/parquet/ParquetTypeSuite.scala
+++ b/parquet/src/test/scala/magnolify/parquet/ParquetTypeSuite.scala
@@ -88,6 +88,9 @@ class ParquetTypeSuite extends MagnolifySuite {
 
   test[ParquetTypes]
 
+  test[MapPrimitive]
+  test[MapNested]
+
   test("AnyVal") {
     implicit val pt: ParquetType[HasValueClass] = ParquetType[HasValueClass]
     test[HasValueClass]
@@ -193,6 +196,8 @@ case class Unsafe(c: Char)
 case class ParquetTypes(b: Byte, s: Short, ba: Array[Byte])
+case class MapPrimitive(strMap: Map[String, Int])
+case class MapNested(m: Map[String, Nested])
 case class Decimal(bd: BigDecimal, bdo: Option[BigDecimal])
 case class Logical(u: UUID, d: LocalDate)
 case class Time(i: Instant, dt: LocalDateTime, ot: OffsetTime, t: LocalTime)
diff --git a/tools/src/test/scala/magnolify/tools/ParquetParserSuite.scala b/tools/src/test/scala/magnolify/tools/ParquetParserSuite.scala
index 42d68477e..6f68e67b3 100644
--- a/tools/src/test/scala/magnolify/tools/ParquetParserSuite.scala
+++ b/tools/src/test/scala/magnolify/tools/ParquetParserSuite.scala
@@ -117,15 +117,26 @@ class ParquetParserSuite extends MagnolifySuite {
     test[DateTime]("nanos", dateTimeSchema)
   }
 
+  private def kvSchema(valueSchema: Schema) = Record(
+    None,
+    None,
+    None,
+    List(
+      Field("key", None, Primitive.String, Required),
+      Field("value", None, valueSchema, Required)
+    )
+  )
+
   private val repetitionsSchema = Record(
     Some("Repetitions"),
     namespace,
     None,
     List(
-      "r" -> Required,
-      "o" -> Optional,
-      "l" -> Repeated
-    ).map(kv => Field(kv._1, None, Primitive.Int, kv._2))
+      Field("r", None, Primitive.Int, Required),
+      Field("o", None, Primitive.Int, Optional),
+      Field("l", None, Primitive.Int, Repeated),
+      Field("m", None, kvSchema(Primitive.Int), Repeated)
+    )
   )
 
   test[Repetitions](repetitionsSchema)
@@ -137,19 +148,19 @@ class ParquetParserSuite extends MagnolifySuite {
 
   private val innerSchema =
     Record(None, None, None, List(Field("i", None, Primitive.Int, Required)))
-
-  test[Outer](
-    Record(
-      Some("Outer"),
-      namespace,
-      None,
-      List(
-        "r" -> Required,
-        "o" -> Optional,
-        "l" -> Repeated
-      ).map(kv => Field(kv._1, None, innerSchema, kv._2))
+  private val outerSchema = Record(
+    Some("Outer"),
+    namespace,
+    None,
+    List(
+      Field("r", None, innerSchema, Required),
+      Field("o", None, innerSchema, Optional),
+      Field("l", None, innerSchema, Repeated),
+      Field("m", None, kvSchema(innerSchema), Repeated)
     )
   )
+
+  test[Outer](outerSchema)
 }
 
 object ParquetParserSuite {
@@ -173,9 +184,7 @@ object ParquetParserSuite {
   case class Decimal(bd: BigDecimal)
   case class Date(d: LocalDate)
   case class DateTime(i: Instant, dt: LocalDateTime, ot: OffsetTime, t: LocalTime)
-
-  case class Repetitions(r: Int, o: Option[Int], l: List[Int])
-
+  case class Repetitions(r: Int, o: Option[Int], l: List[Int], m: Map[String, Int])
   case class Inner(i: Int)
-  case class Outer(r: Inner, o: Option[Inner], l: List[Inner])
+  case class Outer(r: Inner, o: Option[Inner], l: List[Inner], m: Map[String, Inner])
 }
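
Usage sketch (not part of the patch): with the new implicit pfMap in scope, a case class field of type Map[String, T] should derive through ParquetType like any other field and carry the Parquet MAP logical type built by buildSchema above. The Event record, the counters field, and the printed schema shape below are illustrative assumptions, not code from this change; ParquetType derivation is the existing magnolify-parquet API used in ParquetTypeSuite.

  import magnolify.parquet._

  // Hypothetical record type; `counters` exercises the new Map[String, T] support.
  case class Event(id: String, counters: Map[String, Int])

  object EventMapExample {
    // Magnolia derivation should pick up the new implicit pfMap for `counters`.
    val eventType: ParquetType[Event] = ParquetType[Event]

    def main(args: Array[String]): Unit =
      // Assuming ParquetType exposes its Parquet message type via `schema`,
      // the map field should print roughly as a repeated key/value group
      // annotated with the MAP logical type, e.g.:
      //   repeated group counters (MAP) {
      //     required binary key (STRING);
      //     required int32 value;
      //   }
      println(eventType.schema)
  }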