Skip to content

Commit

Permalink
fix: aggregation over other than numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
Pascal-Delange committed Jun 13, 2024
1 parent dbba4c2 commit 81b16da
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 16 deletions.
26 changes: 20 additions & 6 deletions repositories/ingested_data_read_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type IngestedDataReadRepository interface {
exec Executor,
tableName string,
fieldName string,
fieldType models.DataType,
aggregator ast.Aggregator,
filters []ast.Filter,
) (any, error)
Expand Down Expand Up @@ -295,8 +296,13 @@ func queryWithDynamicColumnList(
return output, nil
}

func createQueryAggregated(exec Executor, tableName string,
fieldName string, aggregator ast.Aggregator, filters []ast.Filter,
func createQueryAggregated(
exec Executor,
tableName string,
fieldName string,
fieldType models.DataType,
aggregator ast.Aggregator,
filters []ast.Filter,
) (squirrel.SelectBuilder, error) {
var selectExpression string
if aggregator == ast.AGGREGATOR_COUNT_DISTINCT {
Expand All @@ -305,9 +311,11 @@ func createQueryAggregated(exec Executor, tableName string,
// COUNT(*) is a special case, as it does not take a field name (we do not want to count only non-null
// values of a field, but all rows in the table that match the filters)
selectExpression = "COUNT(*)"
} else {
} else if fieldType == models.Int {
// pgx will build a math/big.Int if we sum postgresql "bigint" (int64) values - we'd rather have a float64.
selectExpression = fmt.Sprintf("%s(%s)::float8", aggregator, fieldName)
} else {
selectExpression = fmt.Sprintf("%s(%s)", aggregator, fieldName)
}

qualifiedTableName := tableNameWithSchema(exec, tableName)
Expand All @@ -327,14 +335,20 @@ func createQueryAggregated(exec Executor, tableName string,
return query, nil
}

func (repo *IngestedDataReadRepositoryImpl) QueryAggregatedValue(ctx context.Context, exec Executor,
tableName string, fieldName string, aggregator ast.Aggregator, filters []ast.Filter,
func (repo *IngestedDataReadRepositoryImpl) QueryAggregatedValue(
ctx context.Context,
exec Executor,
tableName string,
fieldName string,
fieldType models.DataType,
aggregator ast.Aggregator,
filters []ast.Filter,
) (any, error) {
if err := validateClientDbExecutor(exec); err != nil {
return nil, err
}

query, err := createQueryAggregated(exec, tableName, fieldName, aggregator, filters)
query, err := createQueryAggregated(exec, tableName, fieldName, fieldType, aggregator, filters)
if err != nil {
return nil, fmt.Errorf("error while building SQL query: %w", err)
}
Expand Down
28 changes: 22 additions & 6 deletions repositories/ingested_data_read_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,14 @@ func TestIngestedDataGetDbFieldWithJoin(t *testing.T) {
}

func TestIngestedDataQueryAggregatedValueWithoutFilter(t *testing.T) {
query, err := createQueryAggregated(TransactionTest{}, utils.DummyTableNameFirst,
utils.DummyFieldNameForInt, ast.AGGREGATOR_AVG, []ast.Filter{})
query, err := createQueryAggregated(
TransactionTest{},
utils.DummyTableNameFirst,
utils.DummyFieldNameForInt,
models.Int,
ast.AGGREGATOR_AVG,
[]ast.Filter{},
)
assert.Empty(t, err)
sql, args, err := query.ToSql()
assert.Empty(t, err)
Expand All @@ -114,8 +120,13 @@ func TestIngestedDataQueryAggregatedValueWithoutFilter(t *testing.T) {
}

func TestIngestedDataQueryCountWithoutFilter(t *testing.T) {
query, err := createQueryAggregated(TransactionTest{}, utils.DummyTableNameFirst,
utils.DummyFieldNameForInt, ast.AGGREGATOR_COUNT, []ast.Filter{})
query, err := createQueryAggregated(
TransactionTest{},
utils.DummyTableNameFirst,
utils.DummyFieldNameForInt,
models.Int,
ast.AGGREGATOR_COUNT,
[]ast.Filter{})
assert.Empty(t, err)
sql, args, err := query.ToSql()
assert.Empty(t, err)
Expand All @@ -141,8 +152,13 @@ func TestIngestedDataQueryAggregatedValueWithFilter(t *testing.T) {
},
}

query, err := createQueryAggregated(TransactionTest{}, utils.DummyTableNameFirst,
utils.DummyFieldNameForInt, ast.AGGREGATOR_AVG, filters)
query, err := createQueryAggregated(
TransactionTest{},
utils.DummyTableNameFirst,
utils.DummyFieldNameForInt,
models.Int,
ast.AGGREGATOR_AVG,
filters)
assert.Empty(t, err)
sql, args, err := query.ToSql()
assert.Empty(t, err)
Expand Down
13 changes: 9 additions & 4 deletions usecases/ast_eval/evaluate/evaluate_aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (a AggregatorEvaluator) Evaluate(ctx context.Context, arguments ast.Argumen
}
}

result, err := a.runQueryInRepository(ctx, tableName, fieldName, aggregator, filters)
result, err := a.runQueryInRepository(ctx, tableName, fieldName, fieldType, aggregator, filters)
if err != nil {
return MakeEvaluateError(errors.Wrap(err, "Error running aggregation query in repository"))
}
Expand All @@ -97,8 +97,13 @@ func (a AggregatorEvaluator) Evaluate(ctx context.Context, arguments ast.Argumen
return result, nil
}

func (a AggregatorEvaluator) runQueryInRepository(ctx context.Context, tableName string,
fieldName string, aggregator ast.Aggregator, filters []ast.Filter,
func (a AggregatorEvaluator) runQueryInRepository(
ctx context.Context,
tableName string,
fieldName string,
fieldType models.DataType,
aggregator ast.Aggregator,
filters []ast.Filter,
) (any, error) {
if a.ReturnFakeValue {
return DryRunQueryAggregatedValue(a.DataModel, tableName, fieldName, aggregator)
Expand All @@ -108,7 +113,7 @@ func (a AggregatorEvaluator) runQueryInRepository(ctx context.Context, tableName
if err != nil {
return nil, err
}
return a.IngestedDataReadRepository.QueryAggregatedValue(ctx, db, tableName, fieldName, aggregator, filters)
return a.IngestedDataReadRepository.QueryAggregatedValue(ctx, db, tableName, fieldName, fieldType, aggregator, filters)
}

func (a AggregatorEvaluator) defaultValueForAggregator(aggregator ast.Aggregator) (any, []error) {
Expand Down

0 comments on commit 81b16da

Please sign in to comment.