Skip to content

Commit

Permalink
feat(formula): BI-5525 add arr_intersect function (#696)
Browse files Browse the repository at this point in the history
* array intersect function added

* array intersect function tests added

* comment removed

* array intersection function removed from docs & suggest
  • Loading branch information
juliarbkv authored Nov 22, 2024
1 parent 70e91af commit e37c934
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -286,4 +286,6 @@
V(D.CLICKHOUSE, sa.func.arrayJoin),
]
),
# intersect
base.FuncArrayIntersect(variants=[V(D.CLICKHOUSE, sa.func.arrayIntersect)]),
]
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,15 @@ def _array_notcontains(array: ClauseElement, value: ClauseElement) -> ClauseElem
V(D.POSTGRESQL, lambda arr: sa.func.unnest(arr)),
]
),
# intersect
base.FuncArrayIntersect(
variants=[
V(
D.POSTGRESQL,
lambda *arrays: sa.func.array(
sa.intersect(*[sa.select(sa.func.unnest(arr)) for arr in arrays]).scalar_subquery()
),
)
]
),
]
16 changes: 16 additions & 0 deletions lib/dl_formula/dl_formula/definitions/functions_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
FuncStartswith,
)
from dl_formula.definitions.literals import un_literal
from dl_formula.definitions.scope import Scope
from dl_formula.definitions.type_strategy import (
Fixed,
FromArgs,
Expand Down Expand Up @@ -556,6 +557,19 @@ class FuncArrayRemoveDefault(FuncArrayRemoveBase):
]


class FuncArrayIntersect(ArrayFunction):
name = "arr_intersect"
scopes = Function.scopes & ~Scope.SUGGESTED & ~Scope.DOCUMENTED
arg_names = ["array_1", "array_2", "array_3"]
arg_cnt = None
argument_types = [
ArgTypeForAll(DataType.ARRAY_STR, require_type_match={DataType.ARRAY_STR, DataType.CONST_ARRAY_STR}),
ArgTypeForAll(DataType.ARRAY_INT, require_type_match={DataType.ARRAY_INT, DataType.CONST_ARRAY_INT}),
ArgTypeForAll(DataType.ARRAY_FLOAT, require_type_match={DataType.ARRAY_FLOAT, DataType.CONST_ARRAY_FLOAT}),
]
return_type = FromArgs(0)


DEFINITIONS_ARRAY = [
# arr_avg
FuncArrayAvg,
Expand Down Expand Up @@ -628,4 +642,6 @@ class FuncArrayRemoveDefault(FuncArrayRemoveBase):
FuncUnnestArrayFloat,
FuncUnnestArrayInt,
FuncUnnestArrayStr,
# intersect
FuncArrayIntersect,
]
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,92 @@ def test_array_remove(self, dbe: DbEvaluator, data_table: sa.Table) -> None:
assert dbe.eval("ARR_REMOVE([arr_float_value], GET_ITEM([arr_float_value], 1))", from_=data_table) == dbe.eval(
"ARRAY(45, 0.123, NULL)"
)

def test_array_intersection(self, dbe: DbEvaluator, data_table: sa.Table) -> None:
assert dbe.eval("ARR_INTERSECT(ARRAY(1, 2))", from_=data_table) in (
dbe.eval("ARRAY(1, 2)"),
dbe.eval("ARRAY(2, 1)"),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(1, 2), ARRAY(1, 2))", from_=data_table) in (
dbe.eval("ARRAY(1, 2)"),
dbe.eval("ARRAY(2, 1)"),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(1, 2, 2), ARRAY(1, 2, 2))", from_=data_table) in (
dbe.eval("ARRAY(1, 2)"),
dbe.eval("ARRAY(2, 1)"),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(1, 2), ARRAY(3, 4), ARRAY(5, 6))", from_=data_table) in ([], "[]")
assert dbe.eval("ARR_INTERSECT(ARRAY(1, 2), ARRAY(2, 3), ARRAY(3, 4))", from_=data_table) in ([], "[]")
assert dbe.eval("ARR_INTERSECT(ARRAY(1, 2), ARRAY(2), ARRAY(2, 2, 3))", from_=data_table) == dbe.eval(
"ARRAY(2)"
)
assert dbe.eval("ARR_INTERSECT(ARRAY(1, 2, 3), ARRAY(2, 3, 4))", from_=data_table) in (
dbe.eval("ARRAY(2, 3)"),
dbe.eval("ARRAY(3, 2)"),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(2, 3), ARRAY(1, 2, 3, 4))", from_=data_table) in (
dbe.eval("ARRAY(2, 3)"),
dbe.eval("ARRAY(3, 2)"),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(2, 3, 2, 2, 4), ARRAY(1, 2, 3, 2), ARRAY(2, 3, 2))", from_=data_table) in (
dbe.eval("ARRAY(2, 3)"),
dbe.eval("ARRAY(3, 2)"),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(1, 2, 3, NULL), ARRAY(2, 3, 4))", from_=data_table) in (
dbe.eval("ARRAY(2, 3)"),
dbe.eval("ARRAY(3, 2)"),
)
assert dbe.eval(
"ARR_INTERSECT(ARRAY(0, 2, NULL, NULL), ARRAY(0, NULL), ARRAY(2, NULL, 0))", from_=data_table
) in (dbe.eval("ARRAY(0, NULL)"), dbe.eval("ARRAY(NULL, 0)"))
assert dbe.eval("ARR_INTERSECT(ARRAY(0, NULL, NULL), ARRAY(NULL, 0, NULL, NULL))", from_=data_table) in (
dbe.eval("ARRAY(0, NULL)"),
dbe.eval("ARRAY(NULL, 0)"),
)

assert dbe.eval("ARR_INTERSECT(ARRAY(0, 5, 4.999), ARRAY(0, 5.0), ARRAY(4.999))", from_=data_table) in (
[],
"[]",
)
if self.make_decimal_cast:
assert dbe.eval("ARR_INTERSECT(ARRAY(5, 49.999), ARRAY(5, 49.999))", from_=data_table) in (
dbe.eval(
f'ARRAY(DB_CAST(5.0, "{self.make_decimal_cast}", 2, 1), DB_CAST(49.999, "{self.make_decimal_cast}", 5, 3))'
),
dbe.eval(
f'ARRAY(DB_CAST(49.999, "{self.make_decimal_cast}", 5, 3), DB_CAST(5.0, "{self.make_decimal_cast}", 2, 1))'
),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(0, 5, 4.999), ARRAY(0, 5.0))", from_=data_table) in (
dbe.eval(
f'ARRAY(DB_CAST(0.0, "{self.make_decimal_cast}", 2, 1), DB_CAST(5.0, "{self.make_decimal_cast}", 2, 1))'
),
dbe.eval(
f'ARRAY(DB_CAST(5.0, "{self.make_decimal_cast}", 2, 1), DB_CAST(0.0, "{self.make_decimal_cast}", 2, 1))'
),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(0, 5, 4.999), ARRAY(4.999))", from_=data_table) == dbe.eval(
f'ARRAY(DB_CAST(4.999, "{self.make_decimal_cast}", 4, 3))'
)
else:
assert dbe.eval("ARR_INTERSECT(ARRAY(5, 49.999), ARRAY(5, 49.999))", from_=data_table) in (
dbe.eval("ARRAY(5.0, 49.999)"),
dbe.eval("ARRAY(49.999, 5.0)"),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(0, 5, 4.999), ARRAY(0, 5.0))", from_=data_table) in (
dbe.eval("ARRAY(0, 5.0)"),
dbe.eval("ARRAY(5.0, 0)"),
)
assert dbe.eval("ARR_INTERSECT(ARRAY(0, 5, 4.999), ARRAY(4.999))", from_=data_table) == dbe.eval(
"ARRAY(4.999)"
)

assert dbe.eval('ARR_INTERSECT(ARRAY("a", "b", "c"), ARRAY("abc"))', from_=data_table) in ([], "[]")
assert dbe.eval('ARR_INTERSECT(ARRAY("a", "b", "c"), ARRAY("a", "bc"))', from_=data_table) == dbe.eval(
'ARRAY("a")'
)
assert dbe.eval('ARR_INTERSECT(ARRAY("cba"), ARRAY("abc"))', from_=data_table) in ([], "[]")
assert dbe.eval(
'ARR_INTERSECT(ARRAY("ab", "c", "c"), ARRAY("ab", "b", "c", "c"), ARRAY("a", "c", "c", "ab"))',
from_=data_table,
) in (dbe.eval('ARRAY("ab", "c")'), dbe.eval('ARRAY("c", "ab")'))

0 comments on commit e37c934

Please sign in to comment.