From 5616f7eb047f1fd62abaef680f2f0d178f3a1297 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Fri, 24 Jan 2025 16:43:24 -0500 Subject: [PATCH 1/3] Add an auto_record_function annotation for less verbose annotations in record/replay. --- dbt_common/record.py | 148 ++++++++++++++++++++++---------------- tests/unit/test_record.py | 23 +++++- 2 files changed, 109 insertions(+), 62 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index fecbbb2..bca7818 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -11,7 +11,9 @@ import os from enum import Enum -from typing import Any, Callable, Dict, List, Mapping, Optional, Type +from inspect import getfullargspec, signature, FullArgSpec +import typing as tt +from typing import get_type_hints, Any, Callable, Dict, List, Mapping, Optional, Type import contextvars RECORDED_BY_HIGHER_FUNCTION = contextvars.ContextVar("RECORDED_BY_HIGHER_FUNCTION", default=False) @@ -294,70 +296,94 @@ def get_record_types_from_dict(fp: str) -> List: return list(loaded_dct.keys()) +def auto_record_function(record_name: str, method: bool = False, group_name: Optional[str] = None) -> Callable: + return functools.partial(_record_function_inner, record_name, method, False, None) + def record_function( record_type, method: bool = False, tuple_result: bool = False, id_field_name: Optional[str] = None, ) -> Callable: - def record_function_inner(func_to_record): - # To avoid runtime overhead and other unpleasantness, we only apply the - # record/replay decorator if a relevant env var is set. - if get_record_mode_from_env() is None: - return func_to_record - - @functools.wraps(func_to_record) - def record_replay_wrapper(*args, **kwargs) -> Any: - recorder: Optional[Recorder] = None - try: - from dbt_common.context import get_invocation_context - - recorder = get_invocation_context().recorder - except LookupError: - pass - - if recorder is None: - return func_to_record(*args, **kwargs) - - if recorder.recorded_types is not None and not ( - record_type.__name__ in recorder.recorded_types - or record_type.group in recorder.recorded_types - ): - return func_to_record(*args, **kwargs) - - # For methods, peel off the 'self' argument before calling the - # params constructor. - param_args = args[1:] if method else args - if method and id_field_name is not None: - param_args = (getattr(args[0], id_field_name),) + param_args - - params = record_type.params_cls(*param_args, **kwargs) - - include = True - if hasattr(params, "_include"): - include = params._include() - - if not include: - return func_to_record(*args, **kwargs) - - if recorder.mode == RecorderMode.REPLAY: - return recorder.expect_record(params) - if RECORDED_BY_HIGHER_FUNCTION.get(): - return func_to_record(*args, **kwargs) - - RECORDED_BY_HIGHER_FUNCTION.set(True) - r = func_to_record(*args, **kwargs) - result = ( - None - if record_type.result_cls is None - else record_type.result_cls(*r) - if tuple_result - else record_type.result_cls(r) - ) - RECORDED_BY_HIGHER_FUNCTION.set(False) - recorder.add_record(record_type(params=params, result=result)) - return r - - return record_replay_wrapper + return functools.partial(_record_function_inner, record_type, method, tuple_result, id_field_name) + +def get_arg_fields(spec: FullArgSpec): + arg_fields = [] + defaults = len(spec.defaults) + for i, arg in enumerate(spec.args): + annotation = spec.annotations.get(arg) + if i >= len(spec.args) - defaults: + arg_fields.append((arg, annotation, dataclasses.field(default=spec.defaults[i - len(spec.args) + defaults]))) + else: + arg_fields.append((arg, annotation,None)) + + + return arg_fields + +def _record_function_inner(record_type, method, tuple_result, id_field_name, func_to_record): + # To avoid runtime overhead and other unpleasantness, we only apply the + # record/replay decorator if a relevant env var is set. + if get_record_mode_from_env() is None: + return func_to_record + + if isinstance(record_type, str): + return_type = signature(func_to_record).return_annotation + params_cls = dataclasses.make_dataclass(f"{record_type}Params", get_arg_fields(getfullargspec(func_to_record))) + result_cls = dataclasses.make_dataclass(f"{record_type}Result", [("return_val", return_type)]) + + record_type = type(f"{record_type}Record", (Record,), + {"params_cls": params_cls, "result_cls": result_cls}) + + @functools.wraps(func_to_record) + def record_replay_wrapper(*args, **kwargs) -> Any: + recorder: Optional[Recorder] = None + try: + from dbt_common.context import get_invocation_context + + recorder = get_invocation_context().recorder + except LookupError: + pass + + if recorder is None: + return func_to_record(*args, **kwargs) + + if recorder.recorded_types is not None and not ( + record_type.__name__ in recorder.recorded_types + or record_type.group in recorder.recorded_types + ): + return func_to_record(*args, **kwargs) + + # For methods, peel off the 'self' argument before calling the + # params constructor. + param_args = args[1:] if method else args + if method and id_field_name is not None: + param_args = (getattr(args[0], id_field_name),) + param_args + + params = record_type.params_cls(*param_args, **kwargs) + + include = True + if hasattr(params, "_include"): + include = params._include() + + if not include: + return func_to_record(*args, **kwargs) + + if recorder.mode == RecorderMode.REPLAY: + return recorder.expect_record(params) + if RECORDED_BY_HIGHER_FUNCTION.get(): + return func_to_record(*args, **kwargs) + + RECORDED_BY_HIGHER_FUNCTION.set(True) + r = func_to_record(*args, **kwargs) + result = ( + None + if record_type.result_cls is None + else record_type.result_cls(*r) + if tuple_result + else record_type.result_cls(r) + ) + RECORDED_BY_HIGHER_FUNCTION.set(False) + recorder.add_record(record_type(params=params, result=result)) + return r - return record_function_inner + return record_replay_wrapper \ No newline at end of file diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index c84c1ae..5c1ca6d 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -4,7 +4,7 @@ from typing import Optional from dbt_common.context import set_invocation_context, get_invocation_context -from dbt_common.record import record_function, Record, Recorder, RecorderMode +from dbt_common.record import record_function, Record, Recorder, RecorderMode, auto_record_function @dataclasses.dataclass @@ -199,3 +199,24 @@ def outer_func(a: int, b: str) -> str: result = outer_func(123, "abc") assert result == "123abc124abc" + + +def test_auto_decorator_records(setup) -> None: + os.environ["DBT_RECORDER_MODE"] = "Record" + recorder = Recorder(RecorderMode.RECORD, None) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + @auto_record_function("TestAuto") + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + test_func(123, "abc") + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + assert recorder._records_by_type["TestAutoRecord"][-1].params.a == 123 + assert recorder._records_by_type["TestAutoRecord"][-1].params.b == "abc" + assert recorder._records_by_type["TestAutoRecord"][-1].result.return_val == "123abc" \ No newline at end of file From f492d9eb8e18537b181df79e2895f068b8a69c92 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Fri, 24 Jan 2025 16:50:12 -0500 Subject: [PATCH 2/3] Add changelog entry. --- .changes/unreleased/Features-20250124-164923.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Features-20250124-164923.yaml diff --git a/.changes/unreleased/Features-20250124-164923.yaml b/.changes/unreleased/Features-20250124-164923.yaml new file mode 100644 index 0000000..7a88b05 --- /dev/null +++ b/.changes/unreleased/Features-20250124-164923.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add '@auto_record_function' for less verbose record/replay annotaitons +time: 2025-01-24T16:49:23.097806-05:00 +custom: + Author: peterallenwebb + Issue: "240" From c7b584a7e307ab3b2ddb4885c900dcf02fe66892 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Fri, 24 Jan 2025 17:10:48 -0500 Subject: [PATCH 3/3] Fix formatting, type annotations. --- dbt_common/record.py | 46 ++++++++++++++++++++++++++++----------- tests/unit/test_record.py | 6 +---- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index bca7818..d5ddcb3 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -12,8 +12,7 @@ from enum import Enum from inspect import getfullargspec, signature, FullArgSpec -import typing as tt -from typing import get_type_hints, Any, Callable, Dict, List, Mapping, Optional, Type +from typing import Any, Callable, Dict, List, Mapping, Optional, Type import contextvars RECORDED_BY_HIGHER_FUNCTION = contextvars.ContextVar("RECORDED_BY_HIGHER_FUNCTION", default=False) @@ -296,30 +295,46 @@ def get_record_types_from_dict(fp: str) -> List: return list(loaded_dct.keys()) -def auto_record_function(record_name: str, method: bool = False, group_name: Optional[str] = None) -> Callable: +def auto_record_function( + record_name: str, method: bool = False, group_name: Optional[str] = None +) -> Callable: return functools.partial(_record_function_inner, record_name, method, False, None) + def record_function( record_type, method: bool = False, tuple_result: bool = False, id_field_name: Optional[str] = None, ) -> Callable: - return functools.partial(_record_function_inner, record_type, method, tuple_result, id_field_name) + return functools.partial( + _record_function_inner, record_type, method, tuple_result, id_field_name + ) + def get_arg_fields(spec: FullArgSpec): arg_fields = [] - defaults = len(spec.defaults) + defaults = len(spec.defaults) if spec.defaults else 0 for i, arg in enumerate(spec.args): annotation = spec.annotations.get(arg) if i >= len(spec.args) - defaults: - arg_fields.append((arg, annotation, dataclasses.field(default=spec.defaults[i - len(spec.args) + defaults]))) + arg_fields.append( + ( + arg, + annotation, + dataclasses.field( + default=spec.defaults[i - len(spec.args) + defaults] + if spec.defaults + else None + ), # type: ignore + ) + ) else: - arg_fields.append((arg, annotation,None)) - + arg_fields.append((arg, annotation, None)) return arg_fields + def _record_function_inner(record_type, method, tuple_result, id_field_name, func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the # record/replay decorator if a relevant env var is set. @@ -328,11 +343,16 @@ def _record_function_inner(record_type, method, tuple_result, id_field_name, fun if isinstance(record_type, str): return_type = signature(func_to_record).return_annotation - params_cls = dataclasses.make_dataclass(f"{record_type}Params", get_arg_fields(getfullargspec(func_to_record))) - result_cls = dataclasses.make_dataclass(f"{record_type}Result", [("return_val", return_type)]) + params_cls = dataclasses.make_dataclass( + f"{record_type}Params", get_arg_fields(getfullargspec(func_to_record)) + ) + result_cls = dataclasses.make_dataclass( + f"{record_type}Result", [("return_val", return_type)] + ) - record_type = type(f"{record_type}Record", (Record,), - {"params_cls": params_cls, "result_cls": result_cls}) + record_type = type( + f"{record_type}Record", (Record,), {"params_cls": params_cls, "result_cls": result_cls} + ) @functools.wraps(func_to_record) def record_replay_wrapper(*args, **kwargs) -> Any: @@ -386,4 +406,4 @@ def record_replay_wrapper(*args, **kwargs) -> Any: recorder.add_record(record_type(params=params, result=result)) return r - return record_replay_wrapper \ No newline at end of file + return record_replay_wrapper diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index 5c1ca6d..8a4656f 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -213,10 +213,6 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: test_func(123, "abc") - expected_record = TestRecord( - params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") - ) - assert recorder._records_by_type["TestAutoRecord"][-1].params.a == 123 assert recorder._records_by_type["TestAutoRecord"][-1].params.b == "abc" - assert recorder._records_by_type["TestAutoRecord"][-1].result.return_val == "123abc" \ No newline at end of file + assert recorder._records_by_type["TestAutoRecord"][-1].result.return_val == "123abc"