Commit 8d1973f

Add in ability to generate data for all combinations of specified field values
pflooky committed Sep 12, 2024
1 parent d249a35 commit 8d1973f
Showing 4 changed files with 162 additions and 74 deletions.
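For context, a minimal sketch of how the new option is switched on, based on the test added in this commit; the step name, output path, and field values below are illustrative only:

```scala
import io.github.datacatering.datacaterer.api.model.Constants.ONE_OF_GENERATOR
import io.github.datacatering.datacaterer.api.model.{Count, Field, Generator, Schema, Step}

object AllCombinationsUsageSketch {
  // Fields whose allowed values are enumerated via the 'oneOf' generator
  val schema: Schema = Schema(Some(List(
    Field("debit_credit", Some("string"), Some(Generator(ONE_OF_GENERATOR, Map("oneOf" -> List("D", "C"))))),
    Field("status", Some("string"), Some(Generator(ONE_OF_GENERATOR, Map("oneOf" -> List("open", "closed", "suspended")))))
  )))

  // Setting "allCombinations" -> "true" in the step options makes the generator emit one
  // record per combination of the 'oneOf' values (2 x 3 = 6 records in this sketch)
  val step: Step = Step("transaction", "parquet", Count(),
    Map("path" -> "sample/output/parquet/transactions", "allCombinations" -> "true"), schema)
}
```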
@@ -1,14 +1,16 @@
package io.github.datacatering.datacaterer.core.generator

import io.github.datacatering.datacaterer.api.model.Constants.SQL_GENERATOR
import io.github.datacatering.datacaterer.api.model.Constants.{ONE_OF_GENERATOR, SQL_GENERATOR}
import io.github.datacatering.datacaterer.api.model.{Field, PerColumnCount, Step}
import io.github.datacatering.datacaterer.core.exception.InvalidStepCountGeneratorConfigurationException
import io.github.datacatering.datacaterer.core.generator.provider.DataGenerator
import io.github.datacatering.datacaterer.core.generator.provider.OneOfDataGenerator.RandomOneOfDataGenerator
import io.github.datacatering.datacaterer.core.model.Constants._
import io.github.datacatering.datacaterer.core.util.GeneratorUtil.{applySqlExpressions, getDataGenerator}
import io.github.datacatering.datacaterer.core.util.ObjectMapperUtil
import io.github.datacatering.datacaterer.core.util.PlanImplicits.FieldOps
import net.datafaker.Faker
import org.apache.log4j.Logger
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -21,27 +23,33 @@ case class Holder(__index_inc: Long)

class DataGeneratorFactory(faker: Faker)(implicit val sparkSession: SparkSession) {

private val LOGGER = Logger.getLogger(getClass.getName)
private val OBJECT_MAPPER = ObjectMapperUtil.jsonObjectMapper
registerSparkFunctions()

def generateDataForStep(step: Step, dataSourceName: String, startIndex: Long, endIndex: Long): DataFrame = {
//need to have separate code for generating all possible combinations
val structFieldsWithDataGenerators = step.schema.fields.map(getStructWithGenerators).getOrElse(List())
val indexedDf = sparkSession.createDataFrame(Seq.range(startIndex, endIndex).map(Holder))
generateDataViaSql(structFieldsWithDataGenerators, step, indexedDf)
.alias(s"$dataSourceName.${step.name}")
}

private def generateDataViaSql(dataGenerators: List[DataGenerator[_]], step: Step, indexedDf: DataFrame): DataFrame = {
val structType = StructType(dataGenerators.map(_.structField))
val genSqlExpression = dataGenerators.map(dg => s"${dg.generateSqlExpressionWrapper} AS `${dg.structField.name}`")
val df = indexedDf.selectExpr(genSqlExpression: _*)
val allRecordsDf = if (step.options.contains("allCombinations") && step.options("allCombinations").equalsIgnoreCase("true")) {
generateCombinationRecords(dataGenerators, indexedDf)
} else {
val genSqlExpression = dataGenerators.map(dg => s"${dg.generateSqlExpressionWrapper} AS `${dg.structField.name}`")
val df = indexedDf.selectExpr(genSqlExpression: _*)

val perColDf = step.count.perColumn
.map(perCol => generateRecordsPerColumn(dataGenerators, step, perCol, df))
.getOrElse(df)
if (!perColDf.storageLevel.useMemory) perColDf.cache()
step.count.perColumn
.map(perCol => generateRecordsPerColumn(dataGenerators, step, perCol, df))
.getOrElse(df)
}

val dfWithMetadata = attachMetadata(perColDf, structType)
if (!allRecordsDf.storageLevel.useMemory) allRecordsDf.cache()
val structType = StructType(dataGenerators.map(_.structField))
val dfWithMetadata = attachMetadata(allRecordsDf, structType)
val dfAllFields = attachMetadata(applySqlExpressions(dfWithMetadata), structType)
if (!dfAllFields.storageLevel.useMemory) dfAllFields.cache()
dfAllFields
@@ -100,6 +108,31 @@ class DataGeneratorFactory(faker: Faker)(implicit val sparkSession: SparkSession
explodeCount.select(PER_COLUMN_INDEX_COL + ".*", perColumnCount.columnNames: _*)
}

private def generateCombinationRecords(dataGenerators: List[DataGenerator[_]], indexedDf: DataFrame) = {
LOGGER.debug("Attempting to generate all combinations of 'oneOf' fields")
//TODO could be nested oneOf fields
val oneOfFields = dataGenerators
.filter(x => x.isInstanceOf[RandomOneOfDataGenerator] || x.options.contains(ONE_OF_GENERATOR))
val nonOneOfFields = dataGenerators.filter(x => !x.isInstanceOf[RandomOneOfDataGenerator] && !x.options.contains(ONE_OF_GENERATOR))

val oneOfFieldsSql = oneOfFields.map(field => {
val fieldValues = field.structField.metadata.getStringArray(ONE_OF_GENERATOR)
sparkSession.createDataFrame(Seq(1L).map(Holder))
.selectExpr(explode(typedlit(fieldValues)).as(field.structField.name).expr.sql)
})
val nonOneOfFieldsSql = nonOneOfFields.map(dg => s"${dg.generateSqlExpressionWrapper} AS `${dg.structField.name}`")

if (oneOfFields.nonEmpty) {
LOGGER.debug("Found fields defined with 'oneOf', attempting to create all combinations of possible values")
val pairwiseCombinations = oneOfFieldsSql.reduce((a, b) => a.crossJoin(b))
val selectExpr = pairwiseCombinations.columns.toList ++ nonOneOfFieldsSql
pairwiseCombinations.selectExpr(selectExpr: _*)
} else {
LOGGER.debug("No fields defined with 'oneOf', unable to create all possible combinations")
indexedDf
}
}
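The cross-join construction above can be illustrated in isolation: each 'oneOf' field becomes a single-column DataFrame with one row per allowed value, and cross-joining those DataFrames yields the Cartesian product of all values. A minimal standalone sketch (not part of the commit; field names and values are assumed for illustration):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{explode, typedlit}

object CrossJoinCombinationsSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("combinations-sketch").getOrCreate()
  import spark.implicits._

  // One single-column DataFrame per 'oneOf' field, one row per allowed value
  val debitCredit = Seq(1L).toDF("i").select(explode(typedlit(Seq("D", "C"))).as("debit_credit"))
  val status = Seq(1L).toDF("i").select(explode(typedlit(Seq("open", "closed", "suspended"))).as("status"))

  // Cross join gives the Cartesian product: 2 x 3 = 6 rows, one per combination
  debitCredit.crossJoin(status).show()

  spark.stop()
}
```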

private def generateDataWithSchema(dataGenerators: List[DataGenerator[_]]): UserDefinedFunction = {
udf((sqlGen: Int) => {
(1L to sqlGen)
@@ -132,58 +165,59 @@ class DataGeneratorFactory(faker: Faker)(implicit val sparkSession: SparkSession
}

private def defineRandomLengthView(): Unit = {
sparkSession.sql(s"""WITH lengths AS (
| SELECT sequence(1, $DATA_CATERER_RANDOM_LENGTH_MAX_VALUE) AS length_list
|),
|
|-- Explode the sequence into individual length values
|exploded_lengths AS (
| SELECT explode(length_list) AS length
| FROM lengths
|),
|
|-- Create the heuristic cumulative distribution dynamically
|length_distribution AS (
| SELECT
| length,
| CASE
| WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
| WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
| ELSE 0.1 * POWER(2, length - 11)
| END AS weight,
| SUM(CASE
| WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
| WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
| ELSE 0.1 * POWER(2, length - 11)
| END) OVER (ORDER BY length) AS cumulative_weight,
| SUM(CASE
| WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
| WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
| ELSE 0.1 * POWER(2, length - 11)
| END) OVER () AS total_weight
| FROM exploded_lengths
|),
|
|-- Calculate cumulative probabilities
|length_probabilities AS (
| SELECT
| length,
| cumulative_weight / total_weight AS cumulative_prob
| FROM length_distribution
|),
|
|-- Select a single random length based on the heuristic distribution
|random_length AS (
| SELECT
| length
| FROM length_probabilities
| WHERE cumulative_prob >= rand()
| ORDER BY cumulative_prob
| LIMIT 1
|)
|
|-- Final query to get the single random length
|SELECT * FROM random_length;""".stripMargin)
sparkSession.sql(
s"""WITH lengths AS (
| SELECT sequence(1, $DATA_CATERER_RANDOM_LENGTH_MAX_VALUE) AS length_list
|),
|
|-- Explode the sequence into individual length values
|exploded_lengths AS (
| SELECT explode(length_list) AS length
| FROM lengths
|),
|
|-- Create the heuristic cumulative distribution dynamically
|length_distribution AS (
| SELECT
| length,
| CASE
| WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
| WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
| ELSE 0.1 * POWER(2, length - 11)
| END AS weight,
| SUM(CASE
| WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
| WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
| ELSE 0.1 * POWER(2, length - 11)
| END) OVER (ORDER BY length) AS cumulative_weight,
| SUM(CASE
| WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
| WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
| ELSE 0.1 * POWER(2, length - 11)
| END) OVER () AS total_weight
| FROM exploded_lengths
|),
|
|-- Calculate cumulative probabilities
|length_probabilities AS (
| SELECT
| length,
| cumulative_weight / total_weight AS cumulative_prob
| FROM length_distribution
|),
|
|-- Select a single random length based on the heuristic distribution
|random_length AS (
| SELECT
| length
| FROM length_probabilities
| WHERE cumulative_prob >= rand()
| ORDER BY cumulative_prob
| LIMIT 1
|)
|
|-- Final query to get the single random length
|SELECT * FROM random_length;""".stripMargin)
.createOrReplaceTempView(DATA_CATERER_RANDOM_LENGTH)
}
}
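The defineRandomLengthView SQL above samples a string length by an inverse-CDF lookup over heuristic, roughly exponential per-length weights. A plain-Scala sketch of the same sampling idea, assuming the same piecewise weights (illustration only, not the committed code):

```scala
import scala.util.Random

object RandomLengthSamplingSketch {
  // Same piecewise weights as the SQL CASE expression above
  private def weight(length: Int): Double =
    if (length <= 5) 0.001 * math.pow(2, length - 1)
    else if (length <= 10) 0.01 * math.pow(2, length - 6)
    else 0.1 * math.pow(2, length - 11)

  // Inverse-CDF sampling: pick the first length whose cumulative probability reaches a uniform draw
  def randomLength(maxLength: Int): Int = {
    val lengths = 1 to maxLength
    val weights = lengths.map(weight)
    val total = weights.sum
    val cumulativeProbs = weights.scanLeft(0.0)(_ + _).tail.map(_ / total)
    val draw = Random.nextDouble()
    lengths.zip(cumulativeProbs).collectFirst { case (len, p) if p >= draw => len }.getOrElse(maxLength)
  }
}
```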
@@ -2,10 +2,12 @@ package io.github.datacatering.datacaterer.core.generator.provider

import io.github.datacatering.datacaterer.api.model.Constants.{ARRAY_MAXIMUM_LENGTH, ARRAY_MINIMUM_LENGTH, ENABLED_EDGE_CASE, ENABLED_NULL, IS_UNIQUE, PROBABILITY_OF_EDGE_CASE, PROBABILITY_OF_NULL, RANDOM_SEED, STATIC}
import io.github.datacatering.datacaterer.api.model.generator.BaseGenerator
import io.github.datacatering.datacaterer.core.model.Constants.DATA_CATERER_RANDOM_LENGTH
import net.datafaker.Faker
import org.apache.spark.sql.functions.{expr, rand, when}
import org.apache.spark.sql.types.StructField

import java.util.regex.Pattern
import scala.annotation.tailrec
import scala.collection.mutable
import scala.language.higherKinds
@@ -49,7 +51,9 @@ trait DataGenerator[T] extends BaseGenerator[T] with Serializable {
.expr.sql
case _ => baseSqlExpression
}
replaceLambdaFunction(expression)
val replaceLambda = replaceLambdaFunction(expression)
val replaceSubScalar = replaceSubScalarFunction(replaceLambda, baseSqlExpression)
replaceSubScalar
}

def generateWrapper(count: Int = 0): T = {
@@ -75,16 +79,33 @@ }
}
}

@tailrec
private def replaceLambdaFunction(sql: String): String = {
val lambdaRegex = ".*lambdafunction\\((.+?), i\\).*".r.pattern
val matcher = lambdaRegex.matcher(sql)
val replaceTargetFn: String => String = r => s"lambdafunction($r, i)"
val replacementFn: String => String = r => s"i -> $r"
replaceByRegex(sql, lambdaRegex, replaceTargetFn, replacementFn)
}

private def replaceSubScalarFunction(sql: String, originalSql: String): String = {
val lambdaRegex = ".*scalarsubquery\\((.*?)\\).*".r.pattern
val replaceTargetFn: String => String = r => s"scalarsubquery()"
val originalRegex = s".*\\(SELECT CAST\\((.+?) $DATA_CATERER_RANDOM_LENGTH\\).*".r.pattern
val matcher = originalRegex.matcher(originalSql)
if (matcher.matches()) {
val innerFunction = matcher.group(1)
val replace = sql.replace(s"lambdafunction($innerFunction, i)", s"i -> $innerFunction")
replaceLambdaFunction(replace)
val replacementFn: String => String = _ => s"(SELECT CAST(${matcher.group(1)} $DATA_CATERER_RANDOM_LENGTH)"
replaceByRegex(sql, lambdaRegex, replaceTargetFn, replacementFn)
} else sql
}

@tailrec
private def replaceByRegex(text: String, pattern: Pattern, replaceTargetFn: String => String, replacementFn: String => String): String = {
val matcher = pattern.matcher(text)
if (matcher.matches()) {
val innerFunction = matcher.group(1)
val replace = text.replace(replaceTargetFn(innerFunction), replacementFn(innerFunction))
replaceByRegex(replace, pattern, replaceTargetFn, replacementFn)
} else text
}
}

trait NullableDataGenerator[T >: Null] extends DataGenerator[T] {
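The replaceByRegex helper above generalises the lambda rewrite: it repeatedly captures the inner expression of a pattern and substitutes a rewritten form until no match remains. A self-contained sketch of the lambdafunction case only (illustration under that assumption; the committed code also restores scalar subqueries through the same helper):

```scala
import java.util.regex.Pattern
import scala.annotation.tailrec

object LambdaRewriteSketch {
  // Rewrites Spark's parsed `lambdafunction(<body>, i)` form back into lambda arrow syntax `i -> <body>`
  @tailrec
  def rewriteLambdas(sql: String): String = {
    val pattern = Pattern.compile(".*lambdafunction\\((.+?), i\\).*")
    val matcher = pattern.matcher(sql)
    if (matcher.matches()) {
      val body = matcher.group(1)
      rewriteLambdas(sql.replace(s"lambdafunction($body, i)", s"i -> $body"))
    } else sql
  }

  def main(args: Array[String]): Unit = {
    // "transform(arr, lambdafunction(upper(x), i))" becomes "transform(arr, i -> upper(x))"
    println(rewriteLambdas("transform(arr, lambdafunction(upper(x), i))"))
  }
}
```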
@@ -33,8 +33,8 @@ object GeneratorUtil {
case x => throw new UnsupportedDataGeneratorType(x)
}
} else {
LOGGER.debug(s"No generator defined, will default to random generator, field-name=${structField.name}")
RandomDataGenerator.getGeneratorForStructField(structField, faker)
LOGGER.debug(s"No generator defined, will get type of generator based on field options, field-name=${structField.name}")
getDataGenerator(structField, faker)
}
}

@@ -1,6 +1,6 @@
package io.github.datacatering.datacaterer.core.generator

import io.github.datacatering.datacaterer.api.model.Constants.{MAXIMUM_LENGTH, MINIMUM_LENGTH}
import io.github.datacatering.datacaterer.api.model.Constants.{MAXIMUM_LENGTH, MINIMUM_LENGTH, ONE_OF_GENERATOR, RANDOM_GENERATOR, REGEX_GENERATOR, SQL_GENERATOR}
import io.github.datacatering.datacaterer.api.model.{Count, Field, Generator, PerColumnCount, Schema, Step}
import io.github.datacatering.datacaterer.core.util.SparkSuite
import net.datafaker.Faker
@@ -14,11 +14,11 @@ class DataGeneratorFactoryTest extends SparkSuite {
private val dataGeneratorFactory = new DataGeneratorFactory(new Faker() with Serializable)
private val schema = Schema(Some(
List(
Field("id", Some("string"), Some(Generator("random", Map(MINIMUM_LENGTH -> "20", MAXIMUM_LENGTH -> "25")))),
Field("id", Some("string"), Some(Generator(RANDOM_GENERATOR, Map(MINIMUM_LENGTH -> "20", MAXIMUM_LENGTH -> "25")))),
Field("amount", Some("double")),
Field("debit_credit", Some("string"), Some(Generator("oneOf", Map("oneOf" -> List("D", "C"))))),
Field("name", Some("string"), Some(Generator("regex", Map("regex" -> "[A-Z][a-z]{2,6} [A-Z][a-z]{2,8}")))),
Field("code", Some("int"), Some(Generator("sql", Map("sql" -> "CASE WHEN debit_credit == 'D' THEN 1 ELSE 0 END")))),
Field("debit_credit", Some("string"), Some(Generator(ONE_OF_GENERATOR, Map("oneOf" -> List("D", "C"))))),
Field("name", Some("string"), Some(Generator(REGEX_GENERATOR, Map("regex" -> "[A-Z][a-z]{2,6} [A-Z][a-z]{2,8}")))),
Field("code", Some("int"), Some(Generator(SQL_GENERATOR, Map("sql" -> "CASE WHEN debit_credit == 'D' THEN 1 ELSE 0 END")))),
)
))
private val simpleSchema = Schema(Some(List(Field("id"))))
@@ -92,4 +92,37 @@ class DataGeneratorFactoryTest extends SparkSuite {
val sampleRows = df.filter(_.getAs[String]("id") == sampleId)
assert(sampleRows.count() == 1L)
}

test("Can generate data with all possible oneOf combinations enabled in step") {
val step = Step("transaction", "parquet", Count(),
Map("path" -> "sample/output/parquet/transactions", "allCombinations" -> "true"), schema)

val df = dataGeneratorFactory.generateDataForStep(step, "parquet", 0, 15)
df.cache()

assertResult(2L)(df.count())
val idx = df.columns.indexOf("debit_credit")
assert(df.collect().exists(r => r.getString(idx) == "D"))
assert(df.collect().exists(r => r.getString(idx) == "C"))
}

test("Can generate data with all possible oneOf combinations enabled in step with multiple oneOf fields") {
val statusField = Field("status", Some("string"),
Some(Generator(ONE_OF_GENERATOR, Map("oneOf" -> List("open", "closed", "suspended")))))
val fieldsWithStatus = Some(schema.fields.get ++ List(statusField))
val step = Step("transaction", "parquet", Count(),
Map("path" -> "sample/output/parquet/transactions", "allCombinations" -> "true"), schema.copy(fields = fieldsWithStatus))

val df = dataGeneratorFactory.generateDataForStep(step, "parquet", 0, 15)
df.cache()

assertResult(6L)(df.count())
val debitIdx = df.columns.indexOf("debit_credit")
val statusIdx = df.columns.indexOf("status")
assertResult(3)(df.collect().count(r => r.getString(debitIdx) == "D"))
assertResult(3)(df.collect().count(r => r.getString(debitIdx) == "C"))
assertResult(2)(df.collect().count(r => r.getString(statusIdx) == "open"))
assertResult(2)(df.collect().count(r => r.getString(statusIdx) == "closed"))
assertResult(2)(df.collect().count(r => r.getString(statusIdx) == "suspended"))
}
}
