Skip to content

Commit

Permalink
Merge branch 'speedup-round2' of https://github.com/mkljakubowski/avrohugger into mkljakubowski-speedup-round2
Browse files Browse the repository at this point in the history
  • Loading branch information
julianpeeters committed Jan 9, 2025
2 parents 5728f96 + 14ead3e commit a3c6850
Showing 1 changed file with 50 additions and 43 deletions.
93 changes: 50 additions & 43 deletions avrohugger-core/src/main/scala/input/NestedSchemaExtractor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import stores.SchemaStore
import types.EnumAsScalaString

import org.apache.avro.Schema
import org.apache.avro.Schema.Type.{ARRAY, ENUM, FIXED, MAP, RECORD, UNION}
import org.apache.avro.Schema.Type.{ ARRAY, ENUM, FIXED, MAP, RECORD, UNION }

import scala.jdk.CollectionConverters._

Expand All @@ -16,53 +16,60 @@ object NestedSchemaExtractor {
schema: Schema,
schemaStore: SchemaStore,
typeMatcher: TypeMatcher): List[Schema] = {
def extract(
schema: Schema,
fieldPath: List[String] = List.empty): List[Schema] = {
var visitedSchemas = Set.empty[String]

schema.getType match {
case RECORD =>
val fields: List[Schema.Field] = schema.getFields().asScala.toList
val fieldSchemas: List[Schema] = fields.map(field => field.schema)
def flattenSchema(fieldSchema: Schema): List[Schema] = {
fieldSchema.getType match {
case ARRAY => flattenSchema(fieldSchema.getElementType)
case MAP => flattenSchema(fieldSchema.getValueType)
case RECORD => {
// if the field schema is one that has already been stored, use that one
if (schemaStore.schemas.contains(fieldSchema.getFullName)) List()
// if we've already seen this schema (recursive schemas) don't traverse further
else if (fieldPath.contains(fieldSchema.getFullName)) List()
else fieldSchema :: extract(fieldSchema, fieldSchema.getFullName :: fieldPath)
}
case UNION => fieldSchema.getTypes().asScala.toList.flatMap(x => flattenSchema(x))
case ENUM => {
// if the field schema is one that has already been stored, use that one
if (schemaStore.schemas.contains(fieldSchema.getFullName)) List()
else List(fieldSchema)
}
case FIXED => {
// if the field schema is one that has already been stored, use that one
if (schemaStore.schemas.contains(fieldSchema.getFullName)) List()
else List(fieldSchema)
def extract(schema: Schema): List[Schema] = {
if (visitedSchemas.contains(schema.getFullName))
List()
else {
visitedSchemas += schema.getFullName
schema.getType match {
case RECORD =>
val fields: List[Schema.Field] = schema.getFields().asScala.toList
val fieldSchemas: List[Schema] = fields.map(field => field.schema)

def flattenSchema(fieldSchema: Schema): List[Schema] = {
fieldSchema.getType match {
case ARRAY => flattenSchema(fieldSchema.getElementType)
case MAP => flattenSchema(fieldSchema.getValueType)
case RECORD => {
// if the field schema is one that has already been stored, use that one
if (schemaStore.schemas.contains(fieldSchema.getFullName)) List()
// otherwise keep it and recurse; extract returns an empty list for schemas it has already visited (recursive schemas)
else fieldSchema :: extract(fieldSchema)

}
case UNION => fieldSchema.getTypes().asScala.toList.flatMap(x => flattenSchema(x))
case ENUM => {
// if the field schema is one that has already been stored, use that one
if (schemaStore.schemas.contains(fieldSchema.getFullName)) List()
else List(fieldSchema)
}
case FIXED => {
// if the field schema is one that has already been stored, use that one
if (schemaStore.schemas.contains(fieldSchema.getFullName)) List()
else List(fieldSchema)
}
case _ => List(fieldSchema)
}
case _ => List(fieldSchema)
}
}
val flatSchemas = fieldSchemas.flatMap(fieldSchema => flattenSchema(fieldSchema))
def topLevelTypes(schema: Schema) = {
if (typeMatcher.avroScalaTypes.`enum` == EnumAsScalaString) (schema.getType == RECORD | schema.getType == FIXED)
else (schema.getType == RECORD | schema.getType == ENUM | schema.getType == FIXED)
}
val nestedTopLevelSchemas = flatSchemas.filter(topLevelTypes)
nestedTopLevelSchemas
case ENUM => List(schema)
case FIXED => List(schema)
case _ => Nil
}

fieldSchemas
.flatMap(flattenSchema)
.filter { schema =>
if (typeMatcher.avroScalaTypes.`enum` == EnumAsScalaString)
schema.getType == RECORD | schema.getType == FIXED
else
schema.getType == RECORD | schema.getType == ENUM | schema.getType == FIXED
}
case ENUM => List(schema)
case FIXED => List(schema)
case _ => Nil
}
}
}

schema::extract(schema)
schema :: extract(schema)
}
}

0 comments on commit a3c6850

Please sign in to comment.