
Generate scala Map type for parquet
Michel Davit committed Dec 18, 2023
1 parent 9c7c986 commit 45ab819
Showing 9 changed files with 281 additions and 331 deletions.
61 changes: 30 additions & 31 deletions tools/src/main/scala/magnolify/tools/AvroParser.scala
@@ -28,50 +28,47 @@ object AvroParser extends SchemaParser[avro.Schema] {

private def parseRecord(schema: avro.Schema): Record = {
val fields = schema.getFields.asScala.iterator.map { f =>
val (s, r) = parseSchemaAndRepetition(f.schema())
Field(f.name(), Option(f.doc()), s, r)
val s = parseSchema(f.schema())
Record.Field(f.name(), Option(f.doc()), s)
}.toList
Record(Option(schema.getName), Option(schema.getNamespace), Option(schema.getDoc), fields)
Record(
Some(schema.getName),
Option(schema.getDoc),
fields
)
}

private def parseEnum(schema: avro.Schema): Enum =
Enum(
Option(schema.getName),
Option(schema.getNamespace),
private def parseEnum(schema: avro.Schema): Primitive.Enum =
Primitive.Enum(
Some(schema.getName),
Option(schema.getDoc),
schema.getEnumSymbols.asScala.toList
)

private def parseSchemaAndRepetition(schema: avro.Schema): (Schema, Repetition) =
schema.getType match {
case Type.UNION
if schema.getTypes.size() == 2 &&
schema.getTypes.asScala.count(_.getType == Type.NULL) == 1 =>
val s = schema.getTypes.asScala.find(_.getType != Type.NULL).get
if (s.getType == Type.ARRAY) {
// Nullable array, e.g. ["null", {"type": "array", "items": ...}]
(parseSchema(s.getElementType), Repeated)
} else {
(parseSchema(s), Optional)
}
case Type.ARRAY =>
(parseSchema(schema.getElementType), Repeated)
// FIXME: map
case _ =>
(parseSchema(schema), Required)
}

private def parseSchema(schema: avro.Schema): Schema = schema.getType match {
// Nested types
case Type.RECORD => parseRecord(schema)
case Type.ENUM => parseEnum(schema)
// Composite types
case Type.RECORD =>
parseRecord(schema)
case Type.UNION =>
val types = schema.getTypes.asScala
if (types.size != 2 || !types.exists(_.getType == Type.NULL)) {
throw new IllegalArgumentException(s"Unsupported union $schema")
} else {
val s = types.find(_.getType != Type.NULL).get
Optional(parseSchema(s))
}
case Type.ARRAY =>
Repeated(parseSchema(schema.getElementType))
case Type.MAP =>
Mapped(Primitive.String, parseSchema(schema.getValueType))

// Logical types
case Type.STRING if isLogical(schema, LogicalTypes.uuid().getName) =>
Primitive.UUID
case Type.BYTES if schema.getLogicalType.isInstanceOf[LogicalTypes.Decimal] =>
Primitive.BigDecimal
case Type.INT if schema.getLogicalType.isInstanceOf[LogicalTypes.Date] => Primitive.LocalDate
case Type.INT if schema.getLogicalType.isInstanceOf[LogicalTypes.Date] =>
Primitive.LocalDate

// Millis
case Type.LONG if schema.getLogicalType.isInstanceOf[LogicalTypes.TimestampMillis] =>
@@ -92,9 +89,11 @@ object AvroParser extends SchemaParser[avro.Schema] {
Primitive.LocalDateTime

// BigQuery sqlType: DATETIME
case Type.STRING if isLogical(schema, "datetime") => Primitive.LocalDateTime
case Type.STRING if isLogical(schema, "datetime") =>
Primitive.LocalDateTime

// Primitive types
case Type.ENUM => parseEnum(schema)
case Type.FIXED => Primitive.Bytes
case Type.STRING => Primitive.String
case Type.BYTES => Primitive.Bytes
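For illustration, a minimal sketch of how the reworked parser maps Avro types onto the new Schema ADT. The SchemaBuilder construction, the AvroParser.parse entry point and the Long/Double primitive mappings are assumptions here; only the Optional/Repeated/Mapped wrapping is taken from the diff above.

import org.apache.avro.SchemaBuilder
import magnolify.tools._

// Hypothetical Avro record with an optional field, an array field and a map field.
val avroSchema = SchemaBuilder
  .record("Example")
  .fields()
  .optionalString("opt")
  .name("xs").`type`().array().items().longType().noDefault()
  .name("kv").`type`().map().values().doubleType().noDefault()
  .endRecord()

// Expected shape of the parsed schema (repetition is now part of the Schema tree):
// AvroParser.parse(avroSchema) ==
//   Record(
//     Some("Example"),
//     None,
//     List(
//       Record.Field("opt", None, Optional(Primitive.String)),
//       Record.Field("xs", None, Repeated(Primitive.Long)),
//       Record.Field("kv", None, Mapped(Primitive.String, Primitive.Double))
//     )
//   )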
17 changes: 9 additions & 8 deletions tools/src/main/scala/magnolify/tools/BigQueryParser.scala
@@ -26,12 +26,7 @@ object BigQueryParser extends SchemaParser[TableSchema] {

private def parseRecord(fields: List[TableFieldSchema]): Record = {
val fs = fields.map { f =>
val r = f.getMode match {
case "REQUIRED" => Required
case "NULLABLE" => Optional
case "REPEATED" => Repeated
}
val s = f.getType match {
val schema = f.getType match {
case "INT64" => Primitive.Long
case "FLOAT64" => Primitive.Double
case "NUMERIC" => Primitive.BigDecimal
@@ -44,8 +39,14 @@ object BigQueryParser extends SchemaParser[TableSchema] {
case "DATETIME" => Primitive.LocalDateTime
case "STRUCT" => parseRecord(f.getFields.asScala.toList)
}
Field(f.getName, Option(f.getDescription), s, r)

val moddedSchema = f.getMode match {
case "REQUIRED" => schema
case "NULLABLE" => Optional(schema)
case "REPEATED" => Repeated(schema)
}
Record.Field(f.getName, Option(f.getDescription), moddedSchema)
}
Record(None, None, None, fs)
Record(None, None, fs)
}
}
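Similarly, a hedged sketch of the BigQuery side: the column mode is now folded into the schema tree instead of being carried as a separate Repetition. The TableFieldSchema construction and the BigQueryParser.parse entry point are assumptions for illustration.

import com.google.api.services.bigquery.model.{TableFieldSchema, TableSchema}
import magnolify.tools._

// Hypothetical table with a NULLABLE and a REPEATED column.
val tableSchema = new TableSchema().setFields(
  java.util.Arrays.asList(
    new TableFieldSchema().setName("count").setType("INT64").setMode("NULLABLE"),
    new TableFieldSchema().setName("scores").setType("FLOAT64").setMode("REPEATED")
  )
)

// Expected shape of the parsed schema (mode is folded into the Schema tree):
// BigQueryParser.parse(tableSchema) ==
//   Record(None, None, List(
//     Record.Field("count", None, Optional(Primitive.Long)),
//     Record.Field("scores", None, Repeated(Primitive.Double))
//   ))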
58 changes: 35 additions & 23 deletions tools/src/main/scala/magnolify/tools/ParquetParser.scala
@@ -32,43 +32,55 @@ object ParquetParser extends SchemaParser[MessageType] {
override def parse(schema: MessageType): Record = {
val name = schema.getName
val idx = name.lastIndexOf('.')
val n = Some(name.drop(idx + 1))
val ns = Some(name.take(idx)).filter(_.nonEmpty)
parseGroup(schema.asGroupType()).copy(name = n, ns)
val n = name.drop(idx + 1)
parseRecord(schema.asGroupType()).copy(name = Some(n))
}

private def parseRepetition(repetition: Type.Repetition): Repetition = repetition match {
case Type.Repetition.REQUIRED => Required
case Type.Repetition.OPTIONAL => Optional
case Type.Repetition.REPEATED => Repeated
}
private def putRepetition(repetition: Type.Repetition)(schema: Schema): Schema =
repetition match {
case Type.Repetition.REQUIRED => schema
case Type.Repetition.OPTIONAL => Optional(schema)
case Type.Repetition.REPEATED => Repeated(schema)
}

private def parseGroup(groupType: GroupType): Record = {
private def parseRecord(groupType: GroupType): Record = {
val fields = groupType.getFields.asScala.iterator.map { f =>
if (f.isPrimitive) {
val schema = parsePrimitive(f.asPrimitiveType())
Field(f.getName, None, schema, parseRepetition(f.getRepetition))
} else {
val gt = f.asGroupType()
if (isAvroArray(gt)) {
Field(f.getName, None, parseType(gt.getFields.get(0)), Repeated)
} else {
val schema = parseGroup(gt)
Field(f.getName, None, schema, parseRepetition(f.getRepetition))
}
}
Record.Field(f.getName, None, parseType(f))
}.toList
Record(None, None, None, fields)
Record(None, None, fields)
}

private def isAvroArray(groupType: GroupType): Boolean =
groupType.getLogicalTypeAnnotation == LTA.listType() &&
groupType.getFieldCount == 1 &&
groupType.getFieldName(0) == "array" &&
groupType.getFields.get(0).isRepetition(Type.Repetition.REPEATED)
private def parseAvroArray(groupType: GroupType): Schema =
parseType(groupType.getFields.get(0))
private def isMap(groupType: GroupType): Boolean =
groupType.getLogicalTypeAnnotation == LTA.mapType() &&
groupType.getFieldCount == 2 &&
groupType.isRepetition(Type.Repetition.REPEATED)

private def parseMap(groupType: GroupType): Schema = {
val keySchema = parseType(groupType.getFields.get(0))
val valueSchema = parseType(groupType.getFields.get(1))
Mapped(keySchema, valueSchema)
}

private def parseType(tpe: Type): Schema =
if (tpe.isPrimitive) parsePrimitive(tpe.asPrimitiveType()) else parseGroup(tpe.asGroupType())
if (tpe.isPrimitive) {
putRepetition(tpe.getRepetition)(parsePrimitive(tpe.asPrimitiveType()))
} else {
val groupType = tpe.asGroupType()
if (isAvroArray(groupType)) {
parseAvroArray(groupType)
} else if (isMap(groupType)) {
parseMap(groupType)
} else {
putRepetition(tpe.getRepetition)(parseRecord(groupType))
}
}

private def parsePrimitive(primitiveType: PrimitiveType): Primitive = {
val ptn = primitiveType.getPrimitiveTypeName
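A sketch of the map layout the new isMap/parseMap path appears to accept: a repeated two-field group annotated with MAP. The schema string, the use of MessageTypeParser and the key/value primitive mappings are assumptions; the exact group shape written by magnolify's Parquet module may differ.

import org.apache.parquet.schema.MessageTypeParser
import magnolify.tools._

// Hypothetical message containing a string-keyed map, encoded as a repeated
// two-field group annotated with MAP, which is the shape the isMap check accepts.
val message = MessageTypeParser.parseMessageType(
  """message Example {
    |  repeated group kv (MAP) {
    |    required binary key (UTF8);
    |    required int64 value;
    |  }
    |}""".stripMargin
)

// Expected shape of the parsed schema:
// ParquetParser.parse(message) ==
//   Record(Some("Example"), None, List(
//     Record.Field("kv", None, Mapped(Primitive.String, Primitive.Long))
//   ))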
48 changes: 23 additions & 25 deletions tools/src/main/scala/magnolify/tools/Schema.scala
@@ -19,37 +19,35 @@ package magnolify.tools
sealed trait Schema

sealed trait Primitive extends Schema

sealed trait Nested extends Schema

sealed trait Repetition

case object Required extends Repetition
case object Optional extends Repetition
case object Repeated extends Repetition

case class Record(
sealed trait Composite extends Schema
final case class Record(
name: Option[String],
namespace: Option[String],
// namespace: Option[String], // TODO respect namespace
doc: Option[String],
fields: List[Field]
) extends Nested
fields: List[Record.Field]
) extends Composite

case class Field(
name: String,
doc: Option[String],
schema: Schema,
repetition: Repetition
)
object Record {
case class Field(
name: String,
doc: Option[String],
schema: Schema
)

case class Enum(
name: Option[String],
namespace: Option[String],
doc: Option[String],
values: List[String]
) extends Nested
}

case class Optional(schema: Schema) extends Composite
case class Repeated(schema: Schema) extends Composite
case class Mapped(keySchema: Schema, valueSchema: Schema) extends Composite

object Primitive {
final case class Enum(
name: Option[String],
// namespace: Option[String],
doc: Option[String],
values: List[String]
) extends Primitive

case object Null extends Primitive
case object Boolean extends Primitive
case object Char extends Primitive
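Since repetition and maps are now plain Composite wrappers, schemas compose by nesting; a sketch with made-up names:

import magnolify.tools._

// An optional map from string keys to lists of records.
val addresses = Record.Field(
  "addresses",
  None,
  Optional(
    Mapped(
      Primitive.String,
      Repeated(
        Record(
          Some("Address"),
          None,
          List(
            Record.Field("street", None, Primitive.String),
            Record.Field("zip", None, Optional(Primitive.String))
          )
        )
      )
    )
  )
)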
78 changes: 47 additions & 31 deletions tools/src/main/scala/magnolify/tools/SchemaPrinter.scala
@@ -20,36 +20,44 @@ import com.google.common.base.CaseFormat
import org.typelevel.paiges._

object SchemaPrinter {
def print(schema: Record, width: Int = 100): String = renderRecord(schema).renderTrim(width)
private case class RenderContext(field: String, owner: Option[String])

private val RootContext = RenderContext("root", None)

def print(schema: Record, width: Int = 100): String =
renderRecord(RootContext)(schema).renderTrim(width)

private def renderRecord(ctx: RenderContext)(schema: Record): Doc = {
val name = schema.name.getOrElse(toUpperCamel(ctx.field))

private def renderRecord(schema: Record): Doc = {
val name = schema.name.get
val header = Doc.text("case class") + Doc.space + Doc.text(name) + Doc.char('(')
val body = Doc.intercalate(
Doc.char(',') + Doc.lineOrSpace,
schema.fields.map { f =>
val fieldCtx = RenderContext(f.name, Some(name))
val param = quoteIdentifier(f.name)
val tpe = renderType(fieldCtx)(f.schema)
Doc.text(param) + Doc.char(':') + Doc.space + tpe
}
)
val footer = Doc.char(')')
val body =
Doc.intercalate(
Doc.char(',') + Doc.lineOrSpace,
schema.fields.map(renderField(name, _))
)
val caseClass = body.tightBracketBy(header + Doc.lineOrEmpty, Doc.lineOrEmpty + footer)

val companion = renderCompanion(name, schema.fields)
caseClass + companion
}

private def renderCompanion(name: String, fields: List[Field]): Doc = {
private def renderCompanion(name: String, fields: List[Record.Field]): Doc = {
val header = Doc.text("object") + Doc.space + Doc.text(name) + Doc.space + Doc.char('{')
val footer = Doc.char('}')
val nestedFields = fields
.flatMap { f =>
f.schema match {
case record: Record =>
val n = record.name.orElse(Some(toUpperCamel(f.name)))
Some(n.get -> renderRecord(record.copy(name = n)))
case enum: Enum =>
val n = enum.name.orElse(Some(toUpperCamel(f.name)))
Some(n.get -> renderEnum(enum.copy(name = n)))
case _ =>
None
Some(record.name -> renderRecord(RenderContext(f.name, Some(name)))(record))
case enum: Primitive.Enum =>
Some(enum.name -> renderEnum(RenderContext(f.name, Some(name)))(enum))
case _ => None
}
}
.groupBy(_._1)
@@ -71,8 +79,9 @@
}
}

private def renderEnum(schema: Enum): Doc = {
val header = Doc.text("object") + Doc.space + Doc.text(schema.name.get) + Doc.space +
private def renderEnum(ctx: RenderContext)(schema: Primitive.Enum): Doc = {
val name = schema.name.getOrElse(toUpperCamel(ctx.field))
val header = Doc.text("object") + Doc.space + Doc.text(name) + Doc.space +
Doc.text("extends") + Doc.space + Doc.text("Enumeration") + Doc.space +
Doc.char('{')
val footer = Doc.char('}')
Expand All @@ -86,19 +95,26 @@ object SchemaPrinter {

nested(header, body, footer)
}
private def renderOwnerPrefix(ctx: RenderContext): Doc =
ctx.owner.fold(Doc.empty)(o => Doc.text(o) + Doc.char('.'))

private def renderField(name: String, field: Field): Doc = {
val rawType = field.schema match {
case p: Primitive => Doc.text(p.toString)
case r: Record => Doc.text(name + "." + r.name.getOrElse(toUpperCamel(field.name)))
case e: Enum => Doc.text(name + "." + e.name.getOrElse(toUpperCamel(field.name)))
}
val tpe = field.repetition match {
case Required => rawType
case Optional => Doc.text("Option") + Doc.char('[') + rawType + Doc.char(']')
case Repeated => Doc.text("List") + Doc.char('[') + rawType + Doc.char(']')
}
Doc.text(quoteIdentifier(field.name)) + Doc.char(':') + Doc.space + tpe
private def renderType(ctx: RenderContext)(s: Schema): Doc = s match {
case Optional(s) =>
Doc.text("Option") + Doc.char('[') + renderType(ctx)(s) + Doc.char(']')
case Repeated(s) =>
Doc.text("List") + Doc.char('[') + renderType(ctx)(s) + Doc.char(']')
case Mapped(k, v) =>
val keyType = renderType(ctx)(k)
val valueType = renderType(ctx)(v)
Doc.text("Map") + Doc.char('[') + keyType + Doc.char(',') + Doc.space + valueType + Doc.char(
']'
)
case Record(name, _, _) =>
renderOwnerPrefix(ctx) + Doc.text(name.getOrElse(toUpperCamel(ctx.field)))
case Primitive.Enum(name, _, _) =>
renderOwnerPrefix(ctx) + Doc.text(name.getOrElse(toUpperCamel(ctx.field)))
case p: Primitive =>
Doc.text(p.toString)
}

private def nested(header: Doc, body: Doc, footer: Doc): Doc =
@@ -113,7 +129,7 @@
}
}

private def toUpperCamel(name: String): String = {
private[tools] def toUpperCamel(name: String): String = {
var allLower = true
var allUpper = true
var hasHyphen = false
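Finally, a hedged sketch of what the printer is expected to emit for a record containing map-typed fields; the names are made up and the exact line breaking depends on paiges rendering.

import magnolify.tools._

// Hypothetical schema exercising the new Mapped case in renderType.
val event = Record(
  Some("Event"),
  None,
  List(
    Record.Field("id", None, Primitive.String),
    Record.Field("labels", None, Mapped(Primitive.String, Primitive.String)),
    Record.Field("counts", None, Optional(Mapped(Primitive.String, Primitive.Long)))
  )
)

// SchemaPrinter.print(event) should render roughly:
// case class Event(id: String, labels: Map[String, String], counts: Option[Map[String, Long]])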
