Skip to content

Commit

Permalink
var additions
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed Nov 23, 2023
1 parent 6d50ca3 commit 6ab4f39
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 47 deletions.
181 changes: 134 additions & 47 deletions extended/src/main/java/apoc/agg/CollAggregationExtended.java
Original file line number Diff line number Diff line change
@@ -1,60 +1,146 @@
package apoc.agg;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;

import org.apache.commons.math3.stat.descriptive.UnivariateStatistic;
import org.apache.commons.math3.stat.descriptive.moment.*;
import org.apache.commons.math3.stat.descriptive.rank.*;
import org.apache.commons.math3.stat.descriptive.summary.*;
import org.neo4j.procedure.*;

public class CollAggregationExtended {
/*
todo: step 1 https://duckdb.org/docs/sql/aggregates#general-aggregate-functions
any_value(arg) OK
arg_max(arg, val) NO -> apoc.agg.maxItems()
arg_min(arg, val) NO -> apoc.agg.minItems()
avg(arg) no -> avg() in Cypher
bit_and(arg) Returns the bitwise AND of all bits in a given expression. bit_and(A) -
bit_or(arg) Returns the bitwise OR of all bits in a given expression. bit_or(A) -
bit_xor(arg) Returns the bitwise XOR of all bits in a given expression. bit_xor(A) -
bitstring_agg(arg) Returns a bitstring with bits set for each distinct value. bitstring_agg(A) -
bool_and(arg) Returns true if every input value is true, otherwise false. bool_and(A) -
bool_or(arg) Returns true if any input value is true, otherwise false. bool_or(A) -
count(arg) NO
favg(arg) NO
first(arg) --> apoc.agg.first
fsum(arg) --> sum()
geomean(arg) Calculates the geometric mean for all tuples in arg. geomean(A) geometric_mean(A)
histogram(arg) Returns a MAP of key-value pairs representing buckets and counts. histogram(A) -
last(arg) --> apoc.agg.last
list(arg) --> collect()
max(arg) NO -> max()
min(arg) NO -> min()
product(arg) NO -> apoc.agg.product()
string_agg(arg, sep) Concatenates the column string values with a separator string_agg(S, ',') group_concat(arg, sep), listagg(arg, sep)
sum(arg) --> sum()
https://duckdb.org/docs/sql/aggregates#ordered-set-aggregate-functions --> NO percentileCont() and percentileDisc()
TODO IN ANOTHER PR MAYBE:
https://duckdb.org/docs/sql/aggregates#approximate-aggregates
https://duckdb.org/docs/sql/aggregates#statistical-aggregates
*/
/** Message of the exception thrown by {@code bitwiseOperation} when the operator is null or empty. */
public static final String BITWISE_OPERATOR_NOT_DEFINED = "Bitwise operator not defined";

/**
 * Entry point of the {@code apoc.agg.statisticalOperation} user-defined aggregation function.
 *
 * @return a new {@link StatisticalOperation} aggregator with no accumulated values
 */
@UserAggregationFunction("apoc.agg.statisticalOperation")
@Description("TODO")
public StatisticalOperation statisticalOperation() {
return new StatisticalOperation();
}

/**
 * Maps an operation name to a new Apache Commons Math {@link UnivariateStatistic} instance.
 *
 * <p>The name is matched case-insensitively against: {@code SUM}, {@code SUM_OF_SQUARES},
 * {@code PRODUCT}, {@code SUM_OF_LOGS}, {@code MIN}, {@code MAX}, {@code MEAN},
 * {@code VARIANCE}, {@code PERCENTILE}, {@code GEOMETRIC_MEAN}, {@code SKEWNESS},
 * {@code STANDARD_DEVIATION}, {@code SECOND_MOMENT}, {@code KURTOSIS}, {@code SEMI_VARIANCE}.
 *
 * @param type the operation name
 * @return a newly created statistic for the given name
 * @throws RuntimeException if {@code type} is null or is not one of the supported names
 */
public static UnivariateStatistic from(String type) {
    if (type == null) {
        // Fail with the same message as an unknown name instead of a bare NullPointerException.
        throw new RuntimeException("Invalid statistical operation");
    }
    // Locale.ROOT keeps the upper-casing independent of the JVM default locale
    // (e.g. with a Turkish locale "variance" would become "VARİANCE" and never match).
    return switch (type.toUpperCase(Locale.ROOT)) {
        case "SUM" -> new Sum();
        case "SUM_OF_SQUARES" -> new SumOfSquares();
        case "PRODUCT" -> new Product();
        case "SUM_OF_LOGS" -> new SumOfLogs();
        case "MIN" -> new Min();
        case "MAX" -> new Max();
        case "MEAN" -> new Mean();
        case "VARIANCE" -> new Variance();
        case "PERCENTILE" -> new Percentile();
        case "GEOMETRIC_MEAN" -> new GeometricMean();
        case "SKEWNESS" -> new Skewness();
        case "STANDARD_DEVIATION" -> new StandardDeviation();
        case "SECOND_MOMENT" -> new SecondMoment();
        case "KURTOSIS" -> new Kurtosis();
        case "SEMI_VARIANCE" -> new SemiVariance();
        default -> throw new RuntimeException("Invalid statistical operation");
    };
}

/**
 * Aggregator behind {@code apoc.agg.statisticalOperation}.
 *
 * <p>Every row contributes one value; the operation name and the optional
 * {@code begin}/{@code length} window are captured from the first row only and
 * ignored on the following ones. The statistic is evaluated once, in {@link #result()}.
 */
public static class StatisticalOperation {

    // Resolved lazily from the first row; null until update() has been called.
    private UnivariateStatistic operation;
    private long begin;
    private long length;
    private final List<Double> valueList = new ArrayList<>();

    /**
     * Records one value and, on the first call, fixes the operation and the window.
     *
     * @param current   the value to accumulate
     * @param operation the operation name, resolved through {@link CollAggregationExtended#from(String)}
     * @param begin     index of the first value to evaluate, or -1 to evaluate all values
     * @param length    number of values to evaluate, or -1 to evaluate all values
     */
    @UserAggregationUpdate
    public void update(@Name("value") double current,
                       @Name("operation") String operation,
                       @Name(value = "begin", defaultValue = "-1") long begin,
                       @Name(value = "length", defaultValue = "-1") long length) {
        if (this.operation == null) {
            this.operation = CollAggregationExtended.from(operation);
            this.begin = begin;
            this.length = length;
        }

        this.valueList.add(current);
    }

    /**
     * Evaluates the chosen statistic over the accumulated values.
     *
     * @return the statistic over all values, or over the {@code begin}/{@code length} window
     *         when both were supplied; {@code NaN} when no row was aggregated
     * @throws ArithmeticException if {@code begin} or {@code length} does not fit in an int
     */
    @UserAggregationResult
    public double result() {
        if (operation == null) {
            // No row was ever aggregated, so there is neither an operation nor data:
            // return NaN rather than failing with a NullPointerException.
            return Double.NaN;
        }
        double[] doubles = valueList.stream().mapToDouble(Double::doubleValue).toArray();
        if (begin == -1L || length == -1L) {
            return operation.evaluate(doubles);
        }
        // toIntExact rejects out-of-range values instead of silently wrapping them
        // into a different (and wrong) window like a plain (int) cast would.
        return operation.evaluate(doubles, Math.toIntExact(begin), Math.toIntExact(length));
    }
}

/*
VALID:
any_value
bit_and
bit_or
bit_xor
bitstring_agg
bool_and --> name it apoc.agg.all()
bool_or --> name it apoc.agg.any()
geomean --> ??
histogram --> ??
string_agg --> apoc.agg.join
@UserAggregationFunction("apoc.agg.binaryString")
@Description("TODO")
public BitStringFunction binaryString() {
return new BitStringFunction();
}

public static class BitStringFunction {

private final Set<String> value = new HashSet<>();

@UserAggregationUpdate
public void update(@Name("value") Long current) {
String binaryString = Long.toBinaryString(current);
this.value.add(binaryString);
}

@UserAggregationResult
public List<String> result() {
return List.copyOf(value);
}
}

*/

/**
 * Entry point of the {@code apoc.agg.bitwise} user-defined aggregation function.
 *
 * @return a new {@link BitwiseFunction} aggregator with no accumulated value
 */
@UserAggregationFunction("apoc.agg.bitwise")
@Description("TODO")
public BitwiseFunction bitwise() {
return new BitwiseFunction();
}

/**
 * Aggregator behind {@code apoc.agg.bitwise}: folds every incoming value into a
 * running result by delegating to {@code bitwiseOperation}.
 */
public static class BitwiseFunction {

    // Running result of the fold; null before any row has been processed.
    private Long accumulated;

    /**
     * Combines the running result with the next value.
     *
     * @param current  the value of the current row
     * @param operator the bitwise operator to apply
     */
    @UserAggregationUpdate
    public void update(@Name("value") Long current, @Name("operator") final String operator) {
        final Long folded = bitwiseOperation(accumulated, operator, current);
        accumulated = folded;
    }

    /**
     * @return the running result after all rows have been folded in
     */
    @UserAggregationResult
    public Long result() {
        return accumulated;
    }
}

/**
 * Applies a binary bitwise operator to two values.
 *
 * <p>Similar to {@code apoc.bitwise.BitwiseOperations} (using a switch expression), but
 * without the {@code NOT} operator, which does not make much sense in an aggregation function.
 *
 * <p>Supported operators, matched case-insensitively: {@code &}/{@code and},
 * {@code |}/{@code or}, {@code ^}/{@code xor}, {@code >>}/{@code right shift},
 * {@code >>>}/{@code right shift unsigned}, {@code <<}/{@code left shift}.
 *
 * @param a        the left operand; when null, {@code b} is returned as-is and the
 *                 operator is not inspected
 * @param operator the operator symbol or name
 * @param b        the right operand; when null the result is null
 * @return the result of {@code a <operator> b}, or null as described above
 * @throws RuntimeException if the operator is null, empty, or not one of the supported ones
 */
public static Long bitwiseOperation(Long a, String operator, Long b) {
    if (a == null) {
        return b;
    }
    if (operator == null || operator.isEmpty()) {
        throw new RuntimeException(BITWISE_OPERATOR_NOT_DEFINED);
    }
    // "~" is not supported here: it skips this early return only to be rejected
    // by the default branch of the switch below.
    if (!operator.equals("~") && b == null) {
        return null;
    }
    // Locale.ROOT keeps the lower-casing independent of the JVM default locale
    // (e.g. with a Turkish locale "RIGHT SHIFT" would become "rıght shıft" and never match).
    return switch (operator.toLowerCase(Locale.ROOT)) {
        case "&", "and" -> a & b;
        case "|", "or" -> a | b;
        case "^", "xor" -> a ^ b;
        case ">>", "right shift" -> a >> b;
        case ">>>", "right shift unsigned" -> a >>> b;
        case "<<", "left shift" -> a << b;
        default -> throw new RuntimeException("Invalid bitwise operator : '%s'".formatted(operator));
    };
}


@UserAggregationFunction("apoc.agg.any")
@Description("TODO")
Expand All @@ -69,6 +155,7 @@ public static class AnyFunction {

@UserAggregationUpdate
public void update(@Name("value") Boolean value) {

if (!this.value && value == null) {
isNull = true;
} else if (Boolean.TRUE.equals(value)) {
Expand Down
93 changes: 93 additions & 0 deletions extended/src/test/java/apoc/agg/CollAggregationExtendedTest.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package apoc.agg;

import static apoc.agg.CollAggregationExtended.BITWISE_OPERATOR_NOT_DEFINED;
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import apoc.util.TestUtil;
import org.junit.AfterClass;
Expand Down Expand Up @@ -99,4 +103,93 @@ private static void testAnyAllCommon(List<Boolean> list, Consumer<Map<String, Ob
Map.of("list", list),
rowConsumer);
}

/**
 * Checks {@code apoc.agg.bitwise} for every supported operator, using both the symbol
 * and the textual alias, and checks that a null operator is rejected.
 */
@Test
public void testBitwise() {
    long first = 0b0011_1100L;
    long second = 0b0000_1101L;
    long third = 2_100L;
    List<Long> list = List.of(first, second, third);

    try {
        testBitwiseCommon(list, 0, null);
        fail("Should fail since the operator is null");
    } catch (RuntimeException e) {
        assertTrue(e.getMessage().contains(BITWISE_OPERATOR_NOT_DEFINED));
    }

    long expected = first >> second >> third;
    testBitwiseCommon(list, expected, ">>");
    testBitwiseCommon(list, expected, "right shift");

    long expectedLeftShift = first << second << third;
    testBitwiseCommon(list, expectedLeftShift, "<<");
    testBitwiseCommon(list, expectedLeftShift, "left shift");

    // The expected value is computed over all three operands, so the query must
    // aggregate the same three values (previously only the first two were passed).
    long expectedRightShiftUnsigned = first >>> second >>> third;
    testBitwiseCommon(list, expectedRightShiftUnsigned, ">>>");
    testBitwiseCommon(list, expectedRightShiftUnsigned, "right shift unsigned");

    long expectedAnd = first & second & third;
    testBitwiseCommon(list, expectedAnd, "&");
    testBitwiseCommon(list, expectedAnd, "AND");

    long expectedOr = first | second | third;
    testBitwiseCommon(list, expectedOr, "OR");
    testBitwiseCommon(list, expectedOr, "|");

    long expectedXor = first ^ second ^ third;
    testBitwiseCommon(list, expectedXor, "XOR");
    testBitwiseCommon(list, expectedXor, "^");
}

/**
 * Runs {@code apoc.agg.bitwise} over the given values and asserts the aggregated result.
 *
 * @param list     the values unwound into the aggregation
 * @param expected the expected aggregated value
 * @param operator the operator passed to the function
 */
private static void testBitwiseCommon(List<Long> list, long expected, String operator) {
    final String query = "UNWIND $list as value \nRETURN apoc.agg.bitwise(value, $operator) as bitwise";
    testCall(db, query, map("list", list, "operator", operator), row -> {
        final Object actual = row.get("bitwise");
        assertEquals(expected, actual);
    });
}

/**
 * Checks that {@code apoc.agg.binaryString} returns the binary representation of each value.
 */
// NOTE(review): in this commit the apoc.agg.binaryString function appears to sit inside a
// commented-out block of CollAggregationExtended - confirm the function is registered.
// NOTE(review): the commented-out implementation collects into a HashSet, so the
// order asserted here may not be guaranteed - verify.
@Test
public void testBinaryString() {
    final List<Long> values = List.of(0b0011_1100L, 0b0000_1101L, 2_100L);
    final List<Object> expected = List.of("111100", "1101", "100000110100");
    final String query = "UNWIND $list as value \nRETURN apoc.agg.binaryString(value) as result";
    testCall(db, query, map("list", values),
            row -> assertEquals(expected, row.get("result")));
}

/**
 * Invokes {@code apoc.agg.statisticalOperation} once per supported operation name.
 * The expected value passed along is a placeholder: the helper does not assert on it yet.
 */
@Test
public void testStatisticalOperation() {
    final List<String> operations = List.of(
            "SUM",
            "SUM_OF_SQUARES",
            "PRODUCT",
            "SUM_OF_LOGS",
            "MIN",
            "MAX",
            "MEAN",
            "VARIANCE",
            "PERCENTILE",
            "GEOMETRIC_MEAN",
            "SKEWNESS",
            "STANDARD_DEVIATION",
            "SECOND_MOMENT",
            "KURTOSIS",
            "SEMI_VARIANCE");
    for (String operation : operations) {
        testStatisticalCommon(operation, 123.5D);
    }
}

/**
 * Runs {@code apoc.agg.statisticalOperation} over a fixed sample and prints the result.
 *
 * @param operation the operation name passed to the function
 * @param expected  the expected result; currently unused because the assertion is still disabled
 */
private static void testStatisticalCommon(String operation, Double expected) {
    final List<Double> samples = List.of(3.11D, 8.22D, 10.3D, 17.1D, 17.98D, 22.0D);
    final String query = "UNWIND $list as value \nRETURN apoc.agg.statisticalOperation(value, $operation) as result";
    testCall(db, query, map("list", samples, "operation", operation), row -> {
        // TODO - COMPLETE ASSERTIONS
        final Object result = row.get("result");
        System.out.println("operation = " + operation + " -- row.get(\"result\") = " + result);
        // assertEquals(expected, row.get("result"));
    });
}
}

0 comments on commit 6ab4f39

Please sign in to comment.