Skip to content

Commit

Permalink
fix tests on windows
Browse files Browse the repository at this point in the history
  • Loading branch information
mkljakubowski committed Dec 19, 2024
1 parent 56e1443 commit 1f681c0
Show file tree
Hide file tree
Showing 27 changed files with 782 additions and 785 deletions.
36 changes: 19 additions & 17 deletions avrohugger-core/src/main/scala/Generator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ package avrohugger

import avrohugger.format.abstractions.SourceFormat
import avrohugger.format._
import avrohugger.generators.{FileGenerator, StringGenerator}
import avrohugger.input.parsers.{FileInputParser, StringInputParser}
import avrohugger.generators.{ FileGenerator, StringGenerator }
import avrohugger.input.parsers.{ FileInputParser, StringInputParser }
import avrohugger.matchers.TypeMatcher
import avrohugger.types.AvroScalaTypes
import avrohugger.stores.{ClassStore, SchemaStore}
import org.apache.avro.{Protocol, Schema}
import avrohugger.stores.{ ClassStore, SchemaStore }
import org.apache.avro.{ Protocol, Schema }
import java.io.File

// Unable to overload this class' methods because outDir uses a default value
case class Generator(format: SourceFormat,
avroScalaCustomTypes: Option[AvroScalaTypes] = None,
avroScalaCustomNamespace: Map[String, String] = Map.empty,
restrictedFieldNumber: Boolean = false,
classLoader: ClassLoader = Thread.currentThread.getContextClassLoader,
targetScalaPartialVersion: String = avrohugger.internal.ScalaVersion.version) {
avroScalaCustomTypes: Option[AvroScalaTypes] = None,
avroScalaCustomNamespace: Map[String, String] = Map.empty,
restrictedFieldNumber: Boolean = false,
classLoader: ClassLoader = Thread.currentThread.getContextClassLoader,
targetScalaPartialVersion: String = avrohugger.internal.ScalaVersion.version) {

val avroScalaTypes = avroScalaCustomTypes.getOrElse(format.defaultTypes)
val defaultOutputDir = "target/generated-sources"
Expand All @@ -25,20 +25,22 @@ case class Generator(format: SourceFormat,
lazy val schemaParser = new Schema.Parser
val classStore = new ClassStore
val schemaStore = new SchemaStore
val fileGenerator = new FileGenerator
val stringGenerator = new StringGenerator
val typeMatcher = new TypeMatcher(avroScalaTypes, avroScalaCustomNamespace)

//////////////// methods for writing definitions out to file /////////////////
def schemaToFile(
schema: Schema,
outDir: String = defaultOutputDir): Unit = {
FileGenerator.schemaToFile(
fileGenerator.schemaToFile(
schema, outDir, format, classStore, schemaStore, typeMatcher, restrictedFieldNumber, targetScalaPartialVersion)
}

def protocolToFile(
protocol: Protocol,
outDir: String = defaultOutputDir): Unit = {
FileGenerator.protocolToFile(
fileGenerator.protocolToFile(
protocol,
outDir,
format,
Expand All @@ -52,7 +54,7 @@ case class Generator(format: SourceFormat,
def stringToFile(
schemaStr: String,
outDir: String = defaultOutputDir): Unit = {
FileGenerator.stringToFile(
fileGenerator.stringToFile(
schemaStr,
outDir,
format,
Expand All @@ -67,7 +69,7 @@ case class Generator(format: SourceFormat,
def fileToFile(
inFile: File,
outDir: String = defaultOutputDir): Unit = {
FileGenerator.fileToFile(
fileGenerator.fileToFile(
inFile,
outDir,
format,
Expand All @@ -82,17 +84,17 @@ case class Generator(format: SourceFormat,

//////// methods for writing to a list of definitions in String format ///////
def schemaToStrings(schema: Schema): List[String] = {
StringGenerator.schemaToStrings(
stringGenerator.schemaToStrings(
schema, format, classStore, schemaStore, typeMatcher, restrictedFieldNumber, targetScalaPartialVersion)
}

def protocolToStrings(protocol: Protocol): List[String] = {
StringGenerator.protocolToStrings(
stringGenerator.protocolToStrings(
protocol, format, classStore, schemaStore, typeMatcher, restrictedFieldNumber, targetScalaPartialVersion)
}

def stringToStrings(schemaStr: String): List[String] = {
StringGenerator.stringToStrings(
stringGenerator.stringToStrings(
schemaStr,
format,
classStore,
Expand All @@ -104,7 +106,7 @@ case class Generator(format: SourceFormat,
}

def fileToStrings(inFile: File): List[String] = {
StringGenerator.fileToStrings(
stringGenerator.fileToStrings(
inFile,
format,
classStore,
Expand Down
80 changes: 40 additions & 40 deletions avrohugger-core/src/main/scala/format/abstractions/Importer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,38 @@ trait Importer {
// gets enum schemas which may be dependencies
def getEnumSchemas(
topLevelSchemas: List[Schema],
alreadyImported: List[Schema] = List.empty[Schema]): List[Schema] = {
def nextSchemas(s: Schema, us: List[Schema]) = getRecordSchemas(List(s), us)
alreadyImported: Set[Schema] = Set.empty[Schema]): List[Schema] = {
def nextSchemas(s: Schema, us: Set[Schema]) = getRecordSchemas(List(s), us)

topLevelSchemas
.flatMap(schema => {
schema.getType match {
case RECORD =>
val fieldSchemasWithChildSchemas = getFieldSchemas(schema).toSeq
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(s, alreadyImported :+ s))
val fieldSchemasWithChildSchemas = getFieldSchemas(schema).toSet
.intersect(alreadyImported)
.flatMap(s => nextSchemas(s, alreadyImported + s))
Seq(schema) ++ fieldSchemasWithChildSchemas
case ENUM =>
Seq(schema)
case UNION =>
schema.getTypes().asScala
.find(s => s.getType != NULL).toSeq
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
.find(s => s.getType != NULL).toSet
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case MAP =>
Seq(schema.getValueType)
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
Set(schema.getValueType)
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case ARRAY =>
Seq(schema.getElementType)
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
Set(schema.getElementType)
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case _ =>
Seq.empty[Schema]
}
})
.filter(schema => schema.getType == ENUM)
.distinct
.toList
}

def getFixedSchemas(topLevelSchemas: List[Schema]): List[Schema] =
Expand All @@ -88,8 +88,7 @@ trait Importer {
})
.filter(_.getType == FIXED)
.distinct
.toList


def getFieldSchemas(schema: Schema): List[Schema] = {
schema.getFields().asScala.toList.map(field => field.schema)
}
Expand Down Expand Up @@ -126,47 +125,48 @@ trait Importer {

def requiresImportDef(schema: Schema): Boolean = {
(isRecord(schema) || isEnum(schema) || isFixed(schema)) &&
checkNamespace(schema).isDefined &&
checkNamespace(schema) != namespace
checkNamespace(schema).isDefined &&
checkNamespace(schema) != namespace
}

recordSchemas
.filter(schema => requiresImportDef(schema))
.groupBy(schema => checkNamespace(schema).getOrElse(schema.getNamespace))
.toList
.map(group => group match {
case(packageName, fields) => asImportDef(packageName, fields)
case (packageName, fields) => asImportDef(packageName, fields)
})
}

// gets record schemas which may be dependencies
def getRecordSchemas(
topLevelSchemas: List[Schema],
alreadyImported: List[Schema] = List.empty[Schema]): List[Schema] = {
def nextSchemas(s: Schema, us: List[Schema]) = getRecordSchemas(List(s), us)
alreadyImported: Set[Schema] = Set.empty[Schema]): List[Schema] = {
def nextSchemas(s: Schema, us: Set[Schema]) = getRecordSchemas(List(s), us)

topLevelSchemas
.flatMap(schema => {
schema.getType match {
case RECORD =>
val fieldSchemasWithChildSchemas = getFieldSchemas(schema).toSeq
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(s, alreadyImported :+ s))
val fieldSchemasWithChildSchemas = getFieldSchemas(schema).toSet
.intersect(alreadyImported)
.flatMap(s => nextSchemas(s, alreadyImported + s))
Seq(schema) ++ fieldSchemasWithChildSchemas
case ENUM =>
Seq(schema)
case UNION =>
schema.getTypes().asScala
.find(s => s.getType != NULL).toSeq
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
.find(s => s.getType != NULL).toSet
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case MAP =>
Seq(schema.getValueType)
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
Set(schema.getValueType)
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case ARRAY =>
Seq(schema.getElementType)
.filter(s => alreadyImported.contains(s))
.flatMap(s => nextSchemas(schema, alreadyImported :+ s))
Set(schema.getElementType)
.intersect(alreadyImported)
.flatMap(s => nextSchemas(schema, alreadyImported + s))
case _ =>
Seq.empty[Schema]
}
Expand All @@ -177,23 +177,23 @@ trait Importer {
}

def getTopLevelSchemas(
schemaOrProtocol: Either[Schema, Protocol],
schemaOrProtocol: Either[Schema, Protocol],
schemaStore: SchemaStore,
typeMatcher: TypeMatcher): List[Schema] = {
schemaOrProtocol match {
case Left(schema) =>
schema::(NestedSchemaExtractor.getNestedSchemas(schema, schemaStore, typeMatcher))
schema :: (NestedSchemaExtractor.getNestedSchemas(schema, schemaStore, typeMatcher))
case Right(protocol) => protocol.getTypes().asScala.toList.flatMap(schema => {
schema::(NestedSchemaExtractor.getNestedSchemas(schema, schemaStore, typeMatcher))
schema :: (NestedSchemaExtractor.getNestedSchemas(schema, schemaStore, typeMatcher))
})
}

}

def isFixed(schema: Schema): Boolean = ( schema.getType == FIXED )
def isFixed(schema: Schema): Boolean = (schema.getType == FIXED)

def isEnum(schema: Schema): Boolean = ( schema.getType == ENUM )
def isEnum(schema: Schema): Boolean = (schema.getType == ENUM)

def isRecord(schema: Schema): Boolean = ( schema.getType == RECORD )
def isRecord(schema: Schema): Boolean = (schema.getType == RECORD)

}
Loading

0 comments on commit 1f681c0

Please sign in to comment.