Skip to content

Commit

Permalink
Update and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
topher-lo committed Nov 25, 2024
1 parent 6b29027 commit d572293
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 20 deletions.
48 changes: 36 additions & 12 deletions tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,22 @@
import pytest

from tracecat.expressions.functions import (
# Core/Utils
_bool,
_build_safe_lambda,
# Math Operations
add,
# Logical Operations
and_,
# Misc
b64_to_str,
b64url_to_str,
# String Operations
capitalize,
cast,
# IP Address Operations
check_ip_version,
contains,
# Time/Date Operations
create_days,
create_hours,
create_minutes,
create_seconds,
create_weeks,
# Collection Operations
days_between,
# JSON Operations
deserialize_ndjson,
dict_keys,
dict_lookup,
Expand All @@ -53,6 +44,7 @@
greater_than,
greater_than_or_equal,
hours_between,
intersect,
ipv4_in_subnet,
ipv4_is_public,
ipv6_in_subnet,
Expand All @@ -75,7 +67,6 @@
or_,
pow,
prettify_json_str,
# Regular Expressions
regex_extract,
regex_match,
regex_not_match,
Expand All @@ -93,11 +84,9 @@
titleize,
to_datetime,
to_timestamp_str,
# today,
union,
unset_timezone,
uppercase,
# utcnow,
weeks_between,
zip_iterables,
)
Expand Down Expand Up @@ -817,3 +806,38 @@ def test_flatten(input_iterables: list, expected: list) -> None:
# Test with non-iterable input
with pytest.raises((TypeError, AttributeError)):
flatten(123) # type: ignore


@pytest.mark.parametrize(
"items,collection,python_lambda,expected",
[
([1, 2, 3], [2, 3, 4], None, [2, 3]),
# Empty intersection
([1, 2], [3, 4], None, []),
# Empty inputs
([], [1, 2], None, []),
([1, 2], [], None, []),
# Duplicate values
([1, 1, 2], [1, 2, 2], None, [1, 2]),
# String values
(["a", "b"], ["b", "c"], None, ["b"]),
# With lambda transformation
([1, 2, 3], [2, 4, 6], "lambda x: x * 2", [1, 2, 3]),
# Lambda with string manipulation
(
["hello", "world"],
["HELLO", "WORLD"],
"lambda x: x.upper()",
["hello", "world"],
),
# Complex objects
([(1, 2), (3, 4)], [(1, 2), (5, 6)], None, [(1, 2)]),
],
)
def test_intersect(
items: list, collection: list, python_lambda: str | None, expected: list
) -> None:
"""Test the intersect function with various inputs and transformations."""
result = intersect(items, collection, python_lambda)
# Sort the results to ensure consistent comparison
assert sorted(result) == sorted(expected)
17 changes: 9 additions & 8 deletions tracecat/expressions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,15 +662,16 @@ def or_(a: bool, b: bool) -> bool:


def intersect[T: Any](
items: Sequence[T], collection: Sequence[T], jsonpath: str | None = None
items: Sequence[T], collection: Sequence[T], python_lambda: str | None = None
) -> list[T]:
"""Return the set intersection of two sequences as a list."""
"""Return the set intersection of two sequences as a list. If a Python lambda is provided, it will be applied to each item before checking for intersection."""
col_set = set(collection)
if jsonpath:
return list(
{item for item in items if eval_jsonpath(jsonpath, item) in col_set}
)
return list({item for item in items if item in collection})
if python_lambda:
fn = _build_safe_lambda(python_lambda)
result = [item for item in items if fn(item) in col_set]
else:
result = set(items) & col_set
return list(result)


def union[T: Any](*collections: Sequence[T]) -> list[T]:
Expand All @@ -687,7 +688,7 @@ def apply[T: Any](item: T | Iterable[T], python_lambda: str) -> T | list[T]:


def filter_[T: Any](items: Sequence[T], python_lambda: str) -> list[T]:
"""Filter a collection using a Python lambda expression."""
"""Filter a collection using a Python lambda expression as a string (e.g. `"lambda x: x > 2"`)."""
fn = _build_safe_lambda(python_lambda)
return list(filter(fn, items))

Expand Down

0 comments on commit d572293

Please sign in to comment.