diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a701a1046..b562a41370 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ Changes can also be flagged with a GitHub label for tracking purposes. The URL o - Moved non-prod Admin UI dependencies to devDependencies [#5832](https://github.com/ethyca/fides/pull/5832) - Prevent Admin UI and Privacy Center from starting when running `nox -s dev` with datastore params [#5843](https://github.com/ethyca/fides/pull/5843) - Remove plotly (unused package) to reduce fides image size [#5852](https://github.com/ethyca/fides/pull/5852) +- Fixed issue where the log_context decorator didn't support positional arguments [#5866](https://github.com/ethyca/fides/pull/5866) ### Fixed - Fixed pagination bugs on some tables [#5819](https://github.com/ethyca/fides/pull/5819) diff --git a/src/fides/api/util/logger_context_utils.py b/src/fides/api/util/logger_context_utils.py index bcdb2bf496..fac8d9101b 100644 --- a/src/fides/api/util/logger_context_utils.py +++ b/src/fides/api/util/logger_context_utils.py @@ -1,3 +1,4 @@ +import inspect from abc import abstractmethod from enum import Enum from functools import wraps @@ -81,12 +82,46 @@ def decorator(func: Callable) -> Callable: def wrapper(*args: Any, **kwargs: Any) -> Any: context = dict(additional_context) - # extract specified param values from kwargs + # extract specified param values from kwargs and args if capture_args: + # First, process kwargs as they're explicitly named for arg_name, context_name in capture_args.items(): if arg_name in kwargs: context[context_name.value] = kwargs[arg_name] + # Process args using signature binding for more robust parameter mapping + if args: + try: + # Get the signature and bind the arguments + sig = inspect.signature(func) + # This will map positional args to their parameter names correctly + bound_args = sig.bind_partial(*args, **kwargs) + + # Now we can iterate through the bound arguments + for param_name, arg_value in bound_args.arguments.items(): + # Only process if this parameter is in capture_args and wasn't already found in kwargs + if param_name in capture_args and param_name not in kwargs: + context_name = capture_args[param_name] + context[context_name.value] = arg_value + except TypeError: + # Handle the case where the arguments don't match the signature + pass + + # Handle default parameters that weren't provided in args or kwargs + if capture_args: + sig = inspect.signature(func) + for param_name, param in sig.parameters.items(): + # Check if parameter has a default value and is in capture_args + # and hasn't been processed yet (not in context) + if ( + param.default is not param.empty + and param_name in capture_args + and capture_args[param_name].value not in context + ): + context_name = capture_args[param_name] + context[context_name.value] = param.default + + # Process Contextualizable args for arg in args: if isinstance(arg, Contextualizable): arg_context = arg.get_log_context() diff --git a/tests/ops/util/test_logger_context_utils.py b/tests/ops/util/test_logger_context_utils.py index 5023a2b50f..26a21464f5 100644 --- a/tests/ops/util/test_logger_context_utils.py +++ b/tests/ops/util/test_logger_context_utils.py @@ -161,12 +161,119 @@ def func(other_param: str, task_id: str): logger.info("processing") return other_param, task_id - func("something", task_id="abc123") + func("something", "abc123") assert loguru_caplog.records[0].extra == { LoggerContextKeys.task_id.value: "abc123" } + def test_log_context_with_multiple_positional_captured_args(self, loguru_caplog): + """Test that multiple captured args work with positional arguments""" + + @log_context( + capture_args={ + "task_id": LoggerContextKeys.task_id, + "request_id": LoggerContextKeys.privacy_request_id, + } + ) + def func(other_param: str, task_id: str, request_id: str): + logger.info("processing") + return other_param, task_id, request_id + + # Pass all arguments as positional arguments + func("something", "abc123", "req456") + + assert loguru_caplog.records[0].extra == { + LoggerContextKeys.task_id.value: "abc123", + LoggerContextKeys.privacy_request_id.value: "req456", + } + + def test_log_context_with_mixed_positional_and_keyword_only_args( + self, loguru_caplog + ): + """Test that captured args work with functions that have a mix of positional and keyword-only arguments""" + + @log_context( + capture_args={ + "task_id": LoggerContextKeys.task_id, + "request_id": LoggerContextKeys.privacy_request_id, + } + ) + def func(task_id: str, *, request_id: str): + logger.info("processing") + return task_id, request_id + + # Pass task_id as positional and request_id as keyword (required) + func("abc123", request_id="req456") + + assert loguru_caplog.records[0].extra == { + LoggerContextKeys.task_id.value: "abc123", + LoggerContextKeys.privacy_request_id.value: "req456", + } + + def test_log_context_with_keyword_only_args(self, loguru_caplog): + """Test that captured args work with functions that have only keyword-only arguments""" + + @log_context( + capture_args={ + "task_id": LoggerContextKeys.task_id, + "request_id": LoggerContextKeys.privacy_request_id, + } + ) + def func(*, task_id: str, request_id: str): + logger.info("processing") + return task_id, request_id + + # All arguments must be passed as keywords + func(task_id="abc123", request_id="req456") + + assert loguru_caplog.records[0].extra == { + LoggerContextKeys.task_id.value: "abc123", + LoggerContextKeys.privacy_request_id.value: "req456", + } + + def test_log_context_with_default_parameters(self, loguru_caplog): + """Test that captured args work with functions that have default parameters""" + + @log_context( + capture_args={ + "task_id": LoggerContextKeys.task_id, + "request_id": LoggerContextKeys.privacy_request_id, + } + ) + def func(task_id: str = "default_task", request_id: str = "default_request"): + logger.info("processing") + return task_id, request_id + + # Call with no arguments - should use defaults + func() + + assert loguru_caplog.records[0].extra == { + LoggerContextKeys.task_id.value: "default_task", + LoggerContextKeys.privacy_request_id.value: "default_request", + } + + def test_log_context_with_overridden_default_parameters(self, loguru_caplog): + """Test that captured args work with functions where default parameters are overridden""" + + @log_context( + capture_args={ + "task_id": LoggerContextKeys.task_id, + "request_id": LoggerContextKeys.privacy_request_id, + } + ) + def func(task_id: str = "default_task", request_id: str = "default_request"): + logger.info("processing") + return task_id, request_id + + # Override only one default parameter + func(task_id="abc123") + + assert loguru_caplog.records[0].extra == { + LoggerContextKeys.task_id.value: "abc123", + LoggerContextKeys.privacy_request_id.value: "default_request", + } + class TestDetailFunctions: @pytest.fixture