From 45ab8191ce43af65be933bc80ef7322c8ea5119d Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 18 Dec 2023 18:25:26 +0100 Subject: [PATCH] Generate scala Map type for parquet --- .../scala/magnolify/tools/AvroParser.scala | 61 ++++---- .../magnolify/tools/BigQueryParser.scala | 17 +-- .../scala/magnolify/tools/ParquetParser.scala | 58 +++++--- .../main/scala/magnolify/tools/Schema.scala | 48 +++---- .../scala/magnolify/tools/SchemaPrinter.scala | 78 ++++++---- .../magnolify/tools/AvroParserSuite.scala | 89 +++++------- .../magnolify/tools/BigQueryParserSuite.scala | 48 +++---- .../magnolify/tools/ParquetParserSuite.scala | 80 ++++------- .../magnolify/tools/SchemaPrinterSuite.scala | 133 +++++++----------- 9 files changed, 281 insertions(+), 331 deletions(-) diff --git a/tools/src/main/scala/magnolify/tools/AvroParser.scala b/tools/src/main/scala/magnolify/tools/AvroParser.scala index f4c980712..627d81335 100644 --- a/tools/src/main/scala/magnolify/tools/AvroParser.scala +++ b/tools/src/main/scala/magnolify/tools/AvroParser.scala @@ -28,50 +28,47 @@ object AvroParser extends SchemaParser[avro.Schema] { private def parseRecord(schema: avro.Schema): Record = { val fields = schema.getFields.asScala.iterator.map { f => - val (s, r) = parseSchemaAndRepetition(f.schema()) - Field(f.name(), Option(f.doc()), s, r) + val s = parseSchema(f.schema()) + Record.Field(f.name(), Option(f.doc()), s) }.toList - Record(Option(schema.getName), Option(schema.getNamespace), Option(schema.getDoc), fields) + Record( + Some(schema.getName), + Option(schema.getDoc), + fields + ) } - private def parseEnum(schema: avro.Schema): Enum = - Enum( - Option(schema.getName), - Option(schema.getNamespace), + private def parseEnum(schema: avro.Schema): Primitive.Enum = + Primitive.Enum( + Some(schema.getName), Option(schema.getDoc), schema.getEnumSymbols.asScala.toList ) - private def parseSchemaAndRepetition(schema: avro.Schema): (Schema, Repetition) = - schema.getType match { - case Type.UNION - if schema.getTypes.size() == 2 && - schema.getTypes.asScala.count(_.getType == Type.NULL) == 1 => - val s = schema.getTypes.asScala.find(_.getType != Type.NULL).get - if (s.getType == Type.ARRAY) { - // Nullable array, e.g. ["null", {"type": "array", "items": ...}] - (parseSchema(s.getElementType), Repeated) - } else { - (parseSchema(s), Optional) - } - case Type.ARRAY => - (parseSchema(schema.getElementType), Repeated) - // FIXME: map - case _ => - (parseSchema(schema), Required) - } - private def parseSchema(schema: avro.Schema): Schema = schema.getType match { - // Nested types - case Type.RECORD => parseRecord(schema) - case Type.ENUM => parseEnum(schema) + // Composite types + case Type.RECORD => + parseRecord(schema) + case Type.UNION => + val types = schema.getTypes.asScala + if (types.size != 2 || !types.exists(_.getType == Type.NULL)) { + throw new IllegalArgumentException(s"Unsupported union $schema") + } else { + val s = types.find(_.getType != Type.NULL).get + Optional(parseSchema(s)) + } + case Type.ARRAY => + Repeated(parseSchema(schema.getElementType)) + case Type.MAP => + Mapped(Primitive.String, parseSchema(schema.getValueType)) // Logical types case Type.STRING if isLogical(schema, LogicalTypes.uuid().getName) => Primitive.UUID case Type.BYTES if schema.getLogicalType.isInstanceOf[LogicalTypes.Decimal] => Primitive.BigDecimal - case Type.INT if schema.getLogicalType.isInstanceOf[LogicalTypes.Date] => Primitive.LocalDate + case Type.INT if schema.getLogicalType.isInstanceOf[LogicalTypes.Date] => + Primitive.LocalDate // Millis case Type.LONG if schema.getLogicalType.isInstanceOf[LogicalTypes.TimestampMillis] => @@ -92,9 +89,11 @@ object AvroParser extends SchemaParser[avro.Schema] { Primitive.LocalDateTime // BigQuery sqlType: DATETIME - case Type.STRING if isLogical(schema, "datetime") => Primitive.LocalDateTime + case Type.STRING if isLogical(schema, "datetime") => + Primitive.LocalDateTime // Primitive types + case Type.ENUM => parseEnum(schema) case Type.FIXED => Primitive.Bytes case Type.STRING => Primitive.String case Type.BYTES => Primitive.Bytes diff --git a/tools/src/main/scala/magnolify/tools/BigQueryParser.scala b/tools/src/main/scala/magnolify/tools/BigQueryParser.scala index 7fb41bb0d..549f696b7 100644 --- a/tools/src/main/scala/magnolify/tools/BigQueryParser.scala +++ b/tools/src/main/scala/magnolify/tools/BigQueryParser.scala @@ -26,12 +26,7 @@ object BigQueryParser extends SchemaParser[TableSchema] { private def parseRecord(fields: List[TableFieldSchema]): Record = { val fs = fields.map { f => - val r = f.getMode match { - case "REQUIRED" => Required - case "NULLABLE" => Optional - case "REPEATED" => Repeated - } - val s = f.getType match { + val schema = f.getType match { case "INT64" => Primitive.Long case "FLOAT64" => Primitive.Double case "NUMERIC" => Primitive.BigDecimal @@ -44,8 +39,14 @@ object BigQueryParser extends SchemaParser[TableSchema] { case "DATETIME" => Primitive.LocalDateTime case "STRUCT" => parseRecord(f.getFields.asScala.toList) } - Field(f.getName, Option(f.getDescription), s, r) + + val moddedSchema = f.getMode match { + case "REQUIRED" => schema + case "NULLABLE" => Optional(schema) + case "REPEATED" => Repeated(schema) + } + Record.Field(f.getName, Option(f.getDescription), moddedSchema) } - Record(None, None, None, fs) + Record(None, None, fs) } } diff --git a/tools/src/main/scala/magnolify/tools/ParquetParser.scala b/tools/src/main/scala/magnolify/tools/ParquetParser.scala index ab671ee3b..c6d367b85 100644 --- a/tools/src/main/scala/magnolify/tools/ParquetParser.scala +++ b/tools/src/main/scala/magnolify/tools/ParquetParser.scala @@ -32,33 +32,22 @@ object ParquetParser extends SchemaParser[MessageType] { override def parse(schema: MessageType): Record = { val name = schema.getName val idx = name.lastIndexOf('.') - val n = Some(name.drop(idx + 1)) - val ns = Some(name.take(idx)).filter(_.nonEmpty) - parseGroup(schema.asGroupType()).copy(name = n, ns) + val n = name.drop(idx + 1) + parseRecord(schema.asGroupType()).copy(name = Some(n)) } - private def parseRepetition(repetition: Type.Repetition): Repetition = repetition match { - case Type.Repetition.REQUIRED => Required - case Type.Repetition.OPTIONAL => Optional - case Type.Repetition.REPEATED => Repeated - } + private def putRepetition(repetition: Type.Repetition)(schema: Schema): Schema = + repetition match { + case Type.Repetition.REQUIRED => schema + case Type.Repetition.OPTIONAL => Optional(schema) + case Type.Repetition.REPEATED => Repeated(schema) + } - private def parseGroup(groupType: GroupType): Record = { + private def parseRecord(groupType: GroupType): Record = { val fields = groupType.getFields.asScala.iterator.map { f => - if (f.isPrimitive) { - val schema = parsePrimitive(f.asPrimitiveType()) - Field(f.getName, None, schema, parseRepetition(f.getRepetition)) - } else { - val gt = f.asGroupType() - if (isAvroArray(gt)) { - Field(f.getName, None, parseType(gt.getFields.get(0)), Repeated) - } else { - val schema = parseGroup(gt) - Field(f.getName, None, schema, parseRepetition(f.getRepetition)) - } - } + Record.Field(f.getName, None, parseType(f)) }.toList - Record(None, None, None, fields) + Record(None, None, fields) } private def isAvroArray(groupType: GroupType): Boolean = @@ -66,9 +55,32 @@ object ParquetParser extends SchemaParser[MessageType] { groupType.getFieldCount == 1 && groupType.getFieldName(0) == "array" && groupType.getFields.get(0).isRepetition(Type.Repetition.REPEATED) + private def parseAvroArray(groupType: GroupType): Schema = + parseType(groupType.getFields.get(0)) + private def isMap(groupType: GroupType): Boolean = + groupType.getLogicalTypeAnnotation == LTA.mapType() && + groupType.getFieldCount == 2 && + groupType.isRepetition(Type.Repetition.REPEATED) + + private def parseMap(groupType: GroupType): Schema = { + val keySchema = parseType(groupType.getFields.get(0)) + val valueSchema = parseType(groupType.getFields.get(1)) + Mapped(keySchema, valueSchema) + } private def parseType(tpe: Type): Schema = - if (tpe.isPrimitive) parsePrimitive(tpe.asPrimitiveType()) else parseGroup(tpe.asGroupType()) + if (tpe.isPrimitive) { + putRepetition(tpe.getRepetition)(parsePrimitive(tpe.asPrimitiveType())) + } else { + val groupType = tpe.asGroupType() + if (isAvroArray(groupType)) { + parseAvroArray(groupType) + } else if (isMap(groupType)) { + parseMap(groupType) + } else { + putRepetition(tpe.getRepetition)(parseRecord(groupType)) + } + } private def parsePrimitive(primitiveType: PrimitiveType): Primitive = { val ptn = primitiveType.getPrimitiveTypeName diff --git a/tools/src/main/scala/magnolify/tools/Schema.scala b/tools/src/main/scala/magnolify/tools/Schema.scala index f1fbbacd0..407eb0c4f 100644 --- a/tools/src/main/scala/magnolify/tools/Schema.scala +++ b/tools/src/main/scala/magnolify/tools/Schema.scala @@ -19,37 +19,35 @@ package magnolify.tools sealed trait Schema sealed trait Primitive extends Schema - -sealed trait Nested extends Schema - -sealed trait Repetition - -case object Required extends Repetition -case object Optional extends Repetition -case object Repeated extends Repetition - -case class Record( +sealed trait Composite extends Schema +final case class Record( name: Option[String], - namespace: Option[String], +// namespace: Option[String], // TODO respect namespace doc: Option[String], - fields: List[Field] -) extends Nested + fields: List[Record.Field] +) extends Composite -case class Field( - name: String, - doc: Option[String], - schema: Schema, - repetition: Repetition -) +object Record { + case class Field( + name: String, + doc: Option[String], + schema: Schema + ) -case class Enum( - name: Option[String], - namespace: Option[String], - doc: Option[String], - values: List[String] -) extends Nested +} + +case class Optional(schema: Schema) extends Composite +case class Repeated(schema: Schema) extends Composite +case class Mapped(keySchema: Schema, valueSchema: Schema) extends Composite object Primitive { + final case class Enum( + name: Option[String], +// namespace: Option[String], + doc: Option[String], + values: List[String] + ) extends Primitive + case object Null extends Primitive case object Boolean extends Primitive case object Char extends Primitive diff --git a/tools/src/main/scala/magnolify/tools/SchemaPrinter.scala b/tools/src/main/scala/magnolify/tools/SchemaPrinter.scala index f5bb08dcd..b75d9c1ad 100644 --- a/tools/src/main/scala/magnolify/tools/SchemaPrinter.scala +++ b/tools/src/main/scala/magnolify/tools/SchemaPrinter.scala @@ -20,36 +20,44 @@ import com.google.common.base.CaseFormat import org.typelevel.paiges._ object SchemaPrinter { - def print(schema: Record, width: Int = 100): String = renderRecord(schema).renderTrim(width) + private case class RenderContext(field: String, owner: Option[String]) + + private val RootContext = RenderContext("root", None) + + def print(schema: Record, width: Int = 100): String = + renderRecord(RootContext)(schema).renderTrim(width) + + private def renderRecord(ctx: RenderContext)(schema: Record): Doc = { + val name = schema.name.getOrElse(toUpperCamel(ctx.field)) - private def renderRecord(schema: Record): Doc = { - val name = schema.name.get val header = Doc.text("case class") + Doc.space + Doc.text(name) + Doc.char('(') + val body = Doc.intercalate( + Doc.char(',') + Doc.lineOrSpace, + schema.fields.map { f => + val fieldCtx = RenderContext(f.name, Some(name)) + val param = quoteIdentifier(f.name) + val tpe = renderType(fieldCtx)(f.schema) + Doc.text(param) + Doc.char(':') + Doc.space + tpe + } + ) val footer = Doc.char(')') - val body = - Doc.intercalate( - Doc.char(',') + Doc.lineOrSpace, - schema.fields.map(renderField(name, _)) - ) val caseClass = body.tightBracketBy(header + Doc.lineOrEmpty, Doc.lineOrEmpty + footer) + val companion = renderCompanion(name, schema.fields) caseClass + companion } - private def renderCompanion(name: String, fields: List[Field]): Doc = { + private def renderCompanion(name: String, fields: List[Record.Field]): Doc = { val header = Doc.text("object") + Doc.space + Doc.text(name) + Doc.space + Doc.char('{') val footer = Doc.char('}') val nestedFields = fields .flatMap { f => f.schema match { case record: Record => - val n = record.name.orElse(Some(toUpperCamel(f.name))) - Some(n.get -> renderRecord(record.copy(name = n))) - case enum: Enum => - val n = enum.name.orElse(Some(toUpperCamel(f.name))) - Some(n.get -> renderEnum(enum.copy(name = n))) - case _ => - None + Some(record.name -> renderRecord(RenderContext(f.name, Some(name)))(record)) + case enum: Primitive.Enum => + Some(enum.name -> renderEnum(RenderContext(f.name, Some(name)))(enum)) + case _ => None } } .groupBy(_._1) @@ -71,8 +79,9 @@ object SchemaPrinter { } } - private def renderEnum(schema: Enum): Doc = { - val header = Doc.text("object") + Doc.space + Doc.text(schema.name.get) + Doc.space + + private def renderEnum(ctx: RenderContext)(schema: Primitive.Enum): Doc = { + val name = schema.name.getOrElse(toUpperCamel(ctx.field)) + val header = Doc.text("object") + Doc.space + Doc.text(name) + Doc.space + Doc.text("extends") + Doc.space + Doc.text("Enumeration") + Doc.space + Doc.char('{') val footer = Doc.char('}') @@ -86,19 +95,26 @@ object SchemaPrinter { nested(header, body, footer) } + private def renderOwnerPrefix(ctx: RenderContext): Doc = + ctx.owner.fold(Doc.empty)(o => Doc.text(o) + Doc.char('.')) - private def renderField(name: String, field: Field): Doc = { - val rawType = field.schema match { - case p: Primitive => Doc.text(p.toString) - case r: Record => Doc.text(name + "." + r.name.getOrElse(toUpperCamel(field.name))) - case e: Enum => Doc.text(name + "." + e.name.getOrElse(toUpperCamel(field.name))) - } - val tpe = field.repetition match { - case Required => rawType - case Optional => Doc.text("Option") + Doc.char('[') + rawType + Doc.char(']') - case Repeated => Doc.text("List") + Doc.char('[') + rawType + Doc.char(']') - } - Doc.text(quoteIdentifier(field.name)) + Doc.char(':') + Doc.space + tpe + private def renderType(ctx: RenderContext)(s: Schema): Doc = s match { + case Optional(s) => + Doc.text("Option") + Doc.char('[') + renderType(ctx)(s) + Doc.char(']') + case Repeated(s) => + Doc.text("List") + Doc.char('[') + renderType(ctx)(s) + Doc.char(']') + case Mapped(k, v) => + val keyType = renderType(ctx)(k) + val valueType = renderType(ctx)(v) + Doc.text("Map") + Doc.char('[') + keyType + Doc.char(',') + Doc.space + valueType + Doc.char( + ']' + ) + case Record(name, _, _) => + renderOwnerPrefix(ctx) + Doc.text(name.getOrElse(toUpperCamel(ctx.field))) + case Primitive.Enum(name, _, _) => + renderOwnerPrefix(ctx) + Doc.text(name.getOrElse(toUpperCamel(ctx.field))) + case p: Primitive => + Doc.text(p.toString) } private def nested(header: Doc, body: Doc, footer: Doc): Doc = @@ -113,7 +129,7 @@ object SchemaPrinter { } } - private def toUpperCamel(name: String): String = { + private[tools] def toUpperCamel(name: String): String = { var allLower = true var allUpper = true var hasHyphen = false diff --git a/tools/src/test/scala/magnolify/tools/AvroParserSuite.scala b/tools/src/test/scala/magnolify/tools/AvroParserSuite.scala index dd71b1ac3..f14174b30 100644 --- a/tools/src/test/scala/magnolify/tools/AvroParserSuite.scala +++ b/tools/src/test/scala/magnolify/tools/AvroParserSuite.scala @@ -36,43 +36,29 @@ class AvroParserSuite extends MagnolifySuite { } } - private val namespace = Some(this.getClass.getCanonicalName) - test[Primitives]( Record( Some("Primitives"), - namespace, None, List( - "b" -> Primitive.Boolean, - "i" -> Primitive.Int, - "l" -> Primitive.Long, - "f" -> Primitive.Float, - "d" -> Primitive.Double, - "ba" -> Primitive.Bytes, - "s" -> Primitive.String, - "n" -> Primitive.Null - ).map(kv => Field(kv._1, None, kv._2, Required)) + Record.Field("b", None, Primitive.Boolean), + Record.Field("i", None, Primitive.Int), + Record.Field("l", None, Primitive.Long), + Record.Field("f", None, Primitive.Float), + Record.Field("d", None, Primitive.Double), + Record.Field("ba", None, Primitive.Bytes), + Record.Field("s", None, Primitive.String), + Record.Field("n", None, Primitive.Null) + ) ) ) test[Enums]( Record( Some("Enums"), - namespace, None, List( - Field( - "e", - None, - Enum( - Some("Color"), - namespace, - None, - List("Red", "Green", "Blue") - ), - Required - ) + Record.Field("e", None, Primitive.Enum(Some("Color"), None, List("Red", "Green", "Blue"))) ) ) ) @@ -82,11 +68,10 @@ class AvroParserSuite extends MagnolifySuite { test[Logical]( Record( Some("Logical"), - namespace, None, List( - Field("bd", None, Primitive.BigDecimal, Required), - Field("u", None, Primitive.UUID, Required) + Record.Field("bd", None, Primitive.BigDecimal), + Record.Field("u", None, Primitive.UUID) ) ) ) @@ -95,21 +80,19 @@ class AvroParserSuite extends MagnolifySuite { test[Date]( Record( Some("Date"), - namespace, None, - List(Field("d", None, Primitive.LocalDate, Required)) + List(Record.Field("d", None, Primitive.LocalDate)) ) ) private val dateTimeSchema = Record( Some("DateTime"), - namespace, None, List( - "i" -> Primitive.Instant, - "dt" -> Primitive.LocalDateTime, - "t" -> Primitive.LocalTime - ).map(kv => Field(kv._1, None, kv._2, Required)) + Record.Field("i", None, Primitive.Instant), + Record.Field("dt", None, Primitive.LocalDateTime), + Record.Field("t", None, Primitive.LocalTime) + ) ) { @@ -127,41 +110,45 @@ class AvroParserSuite extends MagnolifySuite { test[DateTime]("BigQuery", dateTimeSchema) } - test[Repetitions]( + test[Composite]( Record( - Some("Repetitions"), - namespace, + Some("Composite"), None, List( - "r" -> Required, - "o" -> Optional, - "l" -> Repeated - ).map(kv => Field(kv._1, None, Primitive.Int, kv._2)) + Record.Field("o", None, Optional(Primitive.Int)), + Record.Field("l", None, Repeated(Primitive.Int)), + Record.Field("m", None, Mapped(Primitive.String, Primitive.Int)) + ) ) ) test[NullableArray]( - Record(Some("NullableArray"), namespace, None, List(Field("a", None, Primitive.Int, Repeated))) + Record( + Some("NullableArray"), + None, + List( + Record.Field("a", None, Optional(Repeated(Primitive.Int))) + ) + ) ) private val innerSchema = Record( Some("Inner"), - namespace, None, - List(Field("i", None, Primitive.Int, Required)) + List(Record.Field("i", None, Primitive.Int)) ) test[Outer]( Record( Some("Outer"), - namespace, None, List( - "r" -> Required, - "o" -> Optional, - "l" -> Repeated - ).map(kv => Field(kv._1, None, innerSchema, kv._2)) + Record.Field("r", None, innerSchema), + Record.Field("o", None, Optional(innerSchema)), + Record.Field("l", None, Repeated(innerSchema)), + Record.Field("m", None, Mapped(Primitive.String, innerSchema)) + ) ) ) } @@ -188,9 +175,9 @@ object AvroParserSuite { case class Date(d: LocalDate) case class DateTime(i: Instant, dt: LocalDateTime, t: LocalTime) - case class Repetitions(r: Int, o: Option[Int], l: List[Int]) + case class Composite(o: Option[Int], l: List[Int], m: Map[String, Int]) case class NullableArray(a: Option[List[Int]]) case class Inner(i: Int) - case class Outer(r: Inner, o: Option[Inner], l: List[Inner]) + case class Outer(r: Inner, o: Option[Inner], l: List[Inner], m: Map[String, Inner]) } diff --git a/tools/src/test/scala/magnolify/tools/BigQueryParserSuite.scala b/tools/src/test/scala/magnolify/tools/BigQueryParserSuite.scala index 9cdb62451..b1c3d93cb 100644 --- a/tools/src/test/scala/magnolify/tools/BigQueryParserSuite.scala +++ b/tools/src/test/scala/magnolify/tools/BigQueryParserSuite.scala @@ -39,44 +39,40 @@ class BigQueryParserSuite extends MagnolifySuite { test[Primitives]( Record( - None, None, None, List( - "b" -> Primitive.Boolean, - "l" -> Primitive.Long, - "d" -> Primitive.Double, - "ba" -> Primitive.Bytes, - "s" -> Primitive.String, - "bd" -> Primitive.BigDecimal - ).map(kv => Field(kv._1, None, kv._2, Required)) + Record.Field("b", None, Primitive.Boolean), + Record.Field("l", None, Primitive.Long), + Record.Field("d", None, Primitive.Double), + Record.Field("ba", None, Primitive.Bytes), + Record.Field("s", None, Primitive.String), + Record.Field("bd", None, Primitive.BigDecimal) + ) ) ) test[DateTime]( Record( - None, None, None, List( - "i" -> Primitive.Instant, - "dt" -> Primitive.LocalDateTime, - "d" -> Primitive.LocalDate, - "t" -> Primitive.LocalTime - ).map(kv => Field(kv._1, None, kv._2, Required)) + Record.Field("i", None, Primitive.Instant), + Record.Field("dt", None, Primitive.LocalDateTime), + Record.Field("d", None, Primitive.LocalDate), + Record.Field("t", None, Primitive.LocalTime) + ) ) ) test[Repetitions]( Record( - None, None, None, List( - "r" -> Required, - "o" -> Optional, - "l" -> Repeated - ).map(kv => Field(kv._1, None, Primitive.Long, kv._2)) + Record.Field("o", None, Optional(Primitive.Long)), + Record.Field("l", None, Repeated(Primitive.Long)) + ) ) ) @@ -84,20 +80,18 @@ class BigQueryParserSuite extends MagnolifySuite { Record( None, None, - None, - List(Field("l", None, Primitive.Long, Required)) + List(Record.Field("l", None, Primitive.Long)) ) test[Outer]( Record( - None, None, None, List( - "r" -> Required, - "o" -> Optional, - "l" -> Repeated - ).map(kv => Field(kv._1, None, innerSchema, kv._2)) + Record.Field("r", None, innerSchema), + Record.Field("o", None, Optional(innerSchema)), + Record.Field("l", None, Repeated(innerSchema)) + ) ) ) } @@ -107,7 +101,7 @@ object BigQueryParserSuite { case class DateTime(i: Instant, dt: LocalDateTime, d: LocalDate, t: LocalTime) - case class Repetitions(r: Long, o: Option[Long], l: List[Long]) + case class Repetitions(o: Option[Long], l: List[Long]) case class Inner(l: Long) case class Outer(r: Inner, o: Option[Inner], l: List[Inner]) diff --git a/tools/src/test/scala/magnolify/tools/ParquetParserSuite.scala b/tools/src/test/scala/magnolify/tools/ParquetParserSuite.scala index 6f68e67b3..e4505ad8f 100644 --- a/tools/src/test/scala/magnolify/tools/ParquetParserSuite.scala +++ b/tools/src/test/scala/magnolify/tools/ParquetParserSuite.scala @@ -37,33 +37,29 @@ class ParquetParserSuite extends MagnolifySuite { } } - private val namespace = Some(this.getClass.getCanonicalName) - test[Primitives]( Record( Some("Primitives"), - namespace, None, List( - "b" -> Primitive.Boolean, - "i8" -> Primitive.Byte, - "i16" -> Primitive.Short, - "i32" -> Primitive.Int, - "i64" -> Primitive.Long, - "f" -> Primitive.Float, - "d" -> Primitive.Double, - "ba" -> Primitive.Bytes, - "s" -> Primitive.String, - "e" -> Primitive.String - ).map(kv => Field(kv._1, None, kv._2, Required)) + Record.Field("b", None, Primitive.Boolean), + Record.Field("i8", None, Primitive.Byte), + Record.Field("i16", None, Primitive.Short), + Record.Field("i32", None, Primitive.Int), + Record.Field("i64", None, Primitive.Long), + Record.Field("f", None, Primitive.Float), + Record.Field("d", None, Primitive.Double), + Record.Field("ba", None, Primitive.Bytes), + Record.Field("s", None, Primitive.String), + Record.Field("e", None, Primitive.String) + ) ) ) private val decimalSchema = Record( Some("Decimal"), - namespace, None, - List(Field("bd", None, Primitive.BigDecimal, Required)) + List(Record.Field("bd", None, Primitive.BigDecimal)) ) { @@ -87,19 +83,18 @@ class ParquetParserSuite extends MagnolifySuite { } test[Date]( - Record(Some("Date"), namespace, None, List(Field("d", None, Primitive.LocalDate, Required))) + Record(Some("Date"), None, List(Record.Field("d", None, Primitive.LocalDate))) ) private val dateTimeSchema = Record( Some("DateTime"), - namespace, None, List( - "i" -> Primitive.Instant, - "dt" -> Primitive.LocalDateTime, - "ot" -> Primitive.OffsetTime, - "t" -> Primitive.LocalTime - ).map(kv => Field(kv._1, None, kv._2, Required)) + Record.Field("i", None, Primitive.Instant), + Record.Field("dt", None, Primitive.LocalDateTime), + Record.Field("ot", None, Primitive.OffsetTime), + Record.Field("t", None, Primitive.LocalTime) + ) ) { @@ -117,46 +112,33 @@ class ParquetParserSuite extends MagnolifySuite { test[DateTime]("nanos", dateTimeSchema) } - private def kvSchema(valueSchema: Schema) = Record( - None, - None, - None, - List( - Field("key", None, Primitive.String, Required), - Field("value", None, valueSchema, Required) - ) - ) - - private val repetitionsSchema = Record( - Some("Repetitions"), - namespace, + private val compositeSchema = Record( + Some("Composite"), None, List( - Field("r", None, Primitive.Int, Required), - Field("o", None, Primitive.Int, Optional), - Field("l", None, Primitive.Int, Repeated), - Field("m", None, kvSchema(Primitive.Int), Repeated) + Record.Field("o", None, Optional(Primitive.Int)), + Record.Field("l", None, Repeated(Primitive.Int)), + Record.Field("m", None, Mapped(Primitive.String, Primitive.Int)) ) ) - test[Repetitions](repetitionsSchema) + test[Composite](compositeSchema) { import magnolify.parquet.ParquetArray.AvroCompat._ - test[Repetitions]("Avro", repetitionsSchema) + test[Composite]("Avro", compositeSchema) } private val innerSchema = - Record(None, None, None, List(Field("i", None, Primitive.Int, Required))) + Record(None, None, List(Record.Field("i", None, Primitive.Int))) private val outerSchema = Record( Some("Outer"), - namespace, None, List( - Field("r", None, innerSchema, Required), - Field("o", None, innerSchema, Optional), - Field("l", None, innerSchema, Repeated), - Field("m", None, kvSchema(innerSchema), Repeated) + Record.Field("r", None, innerSchema), + Record.Field("o", None, Optional(innerSchema)), + Record.Field("l", None, Repeated(innerSchema)), + Record.Field("m", None, Mapped(Primitive.String, innerSchema)) ) ) @@ -184,7 +166,7 @@ object ParquetParserSuite { case class Decimal(bd: BigDecimal) case class Date(d: LocalDate) case class DateTime(i: Instant, dt: LocalDateTime, ot: OffsetTime, t: LocalTime) - case class Repetitions(r: Int, o: Option[Int], l: List[Int], m: Map[String, Int]) + case class Composite(o: Option[Int], l: List[Int], m: Map[String, Int]) case class Inner(i: Int) case class Outer(r: Inner, o: Option[Inner], l: List[Inner], m: Map[String, Inner]) } diff --git a/tools/src/test/scala/magnolify/tools/SchemaPrinterSuite.scala b/tools/src/test/scala/magnolify/tools/SchemaPrinterSuite.scala index 0bc3efd01..1f4a29bfa 100644 --- a/tools/src/test/scala/magnolify/tools/SchemaPrinterSuite.scala +++ b/tools/src/test/scala/magnolify/tools/SchemaPrinterSuite.scala @@ -16,12 +16,19 @@ package magnolify.tools -import magnolify.test._ +import magnolify.test.* -class SchemaPrinterSuite extends MagnolifySuite { +class SchemaPrinterSuite extends munit.ScalaCheckSuite { private def test(schema: Record, code: String): Unit = assertEquals(SchemaPrinter.print(schema).trim, code.trim) + test("Root") { + test( + Record(None, None, List(Record.Field("f", None, Primitive.Int))), + "case class Root(f: Int)" + ) + } + test("Primitive") { List( Primitive.Null, @@ -45,29 +52,29 @@ class SchemaPrinterSuite extends MagnolifySuite { Primitive.UUID ).foreach { p => test( - Record(Some("Primitive"), None, None, List(Field("f", None, p, Required))), + Record(Some("Primitive"), None, List(Record.Field("f", None, p))), s"case class Primitive(f: $p)" ) } } - test("Repetition") { + test("Composite") { List( - Required -> "%s", - Optional -> "Option[%s]", - Repeated -> "List[%s]" - ).foreach { case (r, f) => + Optional(Primitive.Int) -> "Option[Int]", + Repeated(Primitive.Int) -> "List[Int]", + Mapped(Primitive.String, Primitive.Int) -> "Map[String, Int]" + ).foreach { case (schema, expected) => test( - Record(Some("Repetition"), None, None, List(Field("f", None, Primitive.Int, r))), - s"case class Repetition(f: ${f.format("Int")})" + Record(Some("Composite"), None, List(Record.Field("f", None, schema))), + s"case class Composite(f: $expected)" ) } } test("Enum") { - val anonymous = Enum(None, None, None, List("Red", "Green", "Blue")) + val anonymous = Primitive.Enum(None, None, List("Red", "Green", "Blue")) test( - Record(Some("Enum"), None, None, List(Field("color_enum", None, anonymous, Required))), + Record(Some("Enum"), None, List(Record.Field("color_enum", None, anonymous))), """case class Enum(color_enum: Enum.ColorEnum) | |object Enum { @@ -81,7 +88,7 @@ class SchemaPrinterSuite extends MagnolifySuite { val named = anonymous.copy(name = Some("Color")) test( - Record(Some("Enum"), None, None, List(Field("color_enum", None, named, Required))), + Record(Some("Enum"), None, List(Record.Field("color_enum", None, named))), """case class Enum(color_enum: Enum.Color) | |object Enum { @@ -95,9 +102,9 @@ class SchemaPrinterSuite extends MagnolifySuite { } test("Record") { - val anonymous = Record(None, None, None, List(Field("f", None, Primitive.Int, Required))) + val anonymous = Record(None, None, List(Record.Field("f", None, Primitive.Int))) test( - Record(Some("Record"), None, None, List(Field("inner_record", None, anonymous, Required))), + Record(Some("Record"), None, List(Record.Field("inner_record", None, anonymous))), """case class Record(inner_record: Record.InnerRecord) | |object Record { @@ -108,7 +115,7 @@ class SchemaPrinterSuite extends MagnolifySuite { val named = anonymous.copy(name = Some("Inner")) test( - Record(Some("Record"), None, None, List(Field("inner_record", None, named, Required))), + Record(Some("Record"), None, List(Record.Field("inner_record", None, named))), """case class Record(inner_record: Record.Inner) | |object Record { @@ -118,87 +125,24 @@ class SchemaPrinterSuite extends MagnolifySuite { ) } - test("UpperCamel") { - val inner = Record(None, None, None, List(Field("f", None, Primitive.Int, Required))) - val field = Field(null, None, inner, Required) - - // lower_underscore - test( - Record(Some("Outer"), None, None, List(field.copy(name = "inner_record"))), - """case class Outer(inner_record: Outer.InnerRecord) - | - |object Outer { - | case class InnerRecord(f: Int) - |} - |""".stripMargin - ) - - // UPPER_UNDERSCORE - test( - Record(Some("Outer"), None, None, List(field.copy(name = "INNER_RECORD"))), - """case class Outer(INNER_RECORD: Outer.InnerRecord) - | - |object Outer { - | case class InnerRecord(f: Int) - |} - |""".stripMargin - ) - - // lowerCamel - test( - Record(Some("Outer"), None, None, List(field.copy(name = "innerRecord"))), - """case class Outer(innerRecord: Outer.InnerRecord) - | - |object Outer { - | case class InnerRecord(f: Int) - |} - |""".stripMargin - ) - - // UpperCamel - test( - Record(Some("Outer"), None, None, List(field.copy(name = "InnerRecord"))), - """case class Outer(InnerRecord: Outer.InnerRecord) - | - |object Outer { - | case class InnerRecord(f: Int) - |} - |""".stripMargin - ) - - // lower-hyphen - test( - Record(Some("Outer"), None, None, List(field.copy(name = "inner-record"))), - """case class Outer(`inner-record`: Outer.InnerRecord) - | - |object Outer { - | case class InnerRecord(f: Int) - |} - |""".stripMargin - ) - } - test("Deduplication") { - val level3 = - Record(Some("Level3"), None, None, List(Field("i3", None, Primitive.Int, Required))) + val level3 = Record(Some("Level3"), None, List(Record.Field("i3", None, Primitive.Int))) val level2 = Record( Some("Level2"), None, - None, List( - Field("r2", None, level3, Required), - Field("o2", None, level3, Optional), - Field("l2", None, level3, Repeated) + Record.Field("r2", None, level3), + Record.Field("o2", None, Optional(level3)), + Record.Field("l2", None, Repeated(level3)) ) ) val level1 = Record( Some("Level1"), None, - None, List( - Field("r1", None, level2, Required), - Field("o1", None, level2, Optional), - Field("l1", None, level2, Repeated) + Record.Field("r1", None, level2), + Record.Field("o1", None, Optional(level2)), + Record.Field("l1", None, Repeated(level2)) ) ) @@ -216,4 +160,21 @@ class SchemaPrinterSuite extends MagnolifySuite { |""".stripMargin ) } + + test("UpperCamel") { + // lower_underscore + assertEquals(SchemaPrinter.toUpperCamel("inner_record"), "InnerRecord") + + // UPPER_UNDERSCORE + assertEquals(SchemaPrinter.toUpperCamel("INNER_RECORD"), "InnerRecord") + + // lowerCamel + assertEquals(SchemaPrinter.toUpperCamel("innerRecord"), "InnerRecord") + + // UpperCamel + assertEquals(SchemaPrinter.toUpperCamel("InnerRecord"), "InnerRecord") + + // lower-hyphen + assertEquals(SchemaPrinter.toUpperCamel("inner-record"), "InnerRecord") + } }