Skip to content

Commit

Permalink
Add a wrapper for single-arg methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jsoucheiron committed Jul 16, 2024
1 parent ceee567 commit 64c24a1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
16 changes: 15 additions & 1 deletion cfripper/config/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ def wrap(*args, **kwargs):

return wrap

def single_param_resolver(f):
def wrap(*args, **kwargs):
calculated_parameters = [arg(kwargs) for arg in args]
if len(calculated_parameters) == 1 and isinstance(calculated_parameters[0], (dict, set)):
result = f(*calculated_parameters, **kwargs)
else:
result = f(calculated_parameters, **kwargs)
if debug:
logger.debug(f"{function_name}({', '.join(str(x) for x in calculated_parameters)}) -> {result}")
return result

return wrap

implemented_filter_functions = {
"and": lambda *args, **kwargs: all(arg(kwargs) for arg in args),
"empty": param_resolver(lambda *args, **kwargs: len(args) == 0),
Expand All @@ -57,7 +70,8 @@ def wrap(*args, **kwargs):
"ref": param_resolver(lambda param_name, **kwargs: get(kwargs, param_name)),
"regex": param_resolver(lambda *args, **kwargs: bool(re.match(*args))),
"regex:ignorecase": param_resolver(lambda *args, **kwargs: bool(re.match(*args, re.IGNORECASE))),
"set": lambda *args, **kwargs: {arg(kwargs) for arg in args},
"set": single_param_resolver(lambda *args, **kwargs: set(*args)),
"sorted": single_param_resolver(lambda *args, **kwargs: sorted(*args)),
}
return implemented_filter_functions[function_name]

Expand Down
35 changes: 18 additions & 17 deletions tests/config/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,26 @@ def template_security_group_firehose_ips():
(Filter(eval={"ref": "param_a.param_b.param_c"}), {"param_a": {"param_b": {"param_c": [1]}}}, [1]),
(Filter(eval={"ref": "param_a.param_b.param_c"}), {"param_a": {"param_b": {"param_c": [-1]}}}, [-1]),
(Filter(eval={"ref": "param_a.param_b.param_c"}), {"param_a": {"param_b": {"param_c": [1.0]}}}, [1.0]),
(Filter(eval={"set": []}), {}, set()),
(Filter(eval={"set": {}}), {}, set()),
(Filter(eval={"set": set()}), {}, set()),
(Filter(eval={"set": {"80"}}), {}, {"80"}),
(Filter(eval={"set": ["80"]}), {}, {"80"}),
(Filter(eval={"set": {"80": 100}}), {}, {"80"}),
(Filter(eval={"set": {"80": 100, "90": 100}}), {}, {"80", "90"}),
(Filter(eval={"set": ["80", "443"]}), {}, {"80", "443"}),
(Filter(eval={"set": {"80", "443"}}), {}, {"80", "443"}),
(Filter(eval={"set": ["80", "443", "8080"]}), {}, {"80", "443", "8080"}),
(Filter(eval={"sorted": []}), {}, []),
(Filter(eval={"sorted": {}}), {}, []),
(Filter(eval={"sorted": set()}), {}, []),
(Filter(eval={"sorted": {"80"}}), {}, ["80"]),
(Filter(eval={"sorted": ["80"]}), {}, ["80"]),
(Filter(eval={"sorted": {"80": 100}}), {}, ["80"]),
(Filter(eval={"sorted": {"80": 100, "90": 100}}), {}, ["80", "90"]),
(Filter(eval={"sorted": ["80", "443"]}), {}, ["443", "80"]),
(Filter(eval={"sorted": {"80", "443"}}), {}, ["443", "80"]),
(Filter(eval={"sorted": ["80", "443", "8080"]}), {}, ["443", "80", "8080"]),
# Composed
(Filter(eval={"eq": [{"ref": "param_a"}, "a"]}), {"param_a": "a"}, True),
(Filter(eval={"eq": ["a", {"ref": "param_a"}]}), {"param_a": "a"}, True),
Expand All @@ -247,23 +265,6 @@ def test_filter(filter_name, args, expected_result):
assert filter_name(**args) == expected_result


@pytest.mark.parametrize(
"filter_name, args, expected_result",
[
(Filter(eval={"set": []}), {}, set()),
(Filter(eval={"set": {"80"}}), {}, {"80"}),
(Filter(eval={"set": ["80"]}), {}, {"80"}),
(Filter(eval={"set": {"80": 100}}), {}, {"80"}),
(Filter(eval={"set": {"80": 100, "90": 100}}), {}, {"80", "90"}),
(Filter(eval={"set": ["80", "443"]}), {}, {"80", "443"}),
(Filter(eval={"set": {"80", "443"}}), {}, {"80", "443"}),
(Filter(eval={"set": ["80", "443", "8080"]}), {}, {"80", "443", "8080"}),
],
)
def test_filter_set(filter_name, args, expected_result):
assert filter_name(**args) == expected_result


def test_exist_function_and_property_does_not_exist(template_cross_account_role_no_name):
mock_config = Config(
rules=["CrossAccountTrustRule"],
Expand Down

0 comments on commit 64c24a1

Please sign in to comment.