Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add '@auto_record_function' for less verbose record/replay annotations #240

Merged
merged 3 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20250124-164923.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add '@auto_record_function' for less verbose record/replay annotations
time: 2025-01-24T16:49:23.097806-05:00
custom:
Author: peterallenwebb
Issue: "240"
162 changes: 104 additions & 58 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os

from enum import Enum
from inspect import getfullargspec, signature, FullArgSpec
from typing import Any, Callable, Dict, List, Mapping, Optional, Type
import contextvars

Expand Down Expand Up @@ -294,70 +295,115 @@ def get_record_types_from_dict(fp: str) -> List:
return list(loaded_dct.keys())


def auto_record_function(
    record_name: str, method: bool = False, group_name: Optional[str] = None
) -> Callable:
    """Build a record/replay decorator whose record type is generated automatically.

    ``record_name`` is forwarded as a plain string to ``_record_function_inner``,
    which derives ``<record_name>Params``, ``<record_name>Result`` and
    ``<record_name>Record`` classes from the decorated function's signature.

    Args:
        record_name: Base name used for the generated record classes.
        method: True when the decorated callable is a method.
        group_name: NOTE(review): accepted but not forwarded anywhere in this
            function, so it currently has no effect.

    Returns:
        A decorator that takes the function to record.
    """
    # Auto-generated records never unpack a tuple result and never carry an
    # id field taken from 'self'.
    tuple_result = False
    id_field_name = None
    return functools.partial(
        _record_function_inner, record_name, method, tuple_result, id_field_name
    )


def record_function(
record_type,
method: bool = False,
tuple_result: bool = False,
id_field_name: Optional[str] = None,
) -> Callable:
def record_function_inner(func_to_record):
# To avoid runtime overhead and other unpleasantness, we only apply the
# record/replay decorator if a relevant env var is set.
if get_record_mode_from_env() is None:
return func_to_record

@functools.wraps(func_to_record)
def record_replay_wrapper(*args, **kwargs) -> Any:
recorder: Optional[Recorder] = None
try:
from dbt_common.context import get_invocation_context

recorder = get_invocation_context().recorder
except LookupError:
pass

if recorder is None:
return func_to_record(*args, **kwargs)

if recorder.recorded_types is not None and not (
record_type.__name__ in recorder.recorded_types
or record_type.group in recorder.recorded_types
):
return func_to_record(*args, **kwargs)

# For methods, peel off the 'self' argument before calling the
# params constructor.
param_args = args[1:] if method else args
if method and id_field_name is not None:
param_args = (getattr(args[0], id_field_name),) + param_args

params = record_type.params_cls(*param_args, **kwargs)

include = True
if hasattr(params, "_include"):
include = params._include()

if not include:
return func_to_record(*args, **kwargs)

if recorder.mode == RecorderMode.REPLAY:
return recorder.expect_record(params)
if RECORDED_BY_HIGHER_FUNCTION.get():
return func_to_record(*args, **kwargs)

RECORDED_BY_HIGHER_FUNCTION.set(True)
r = func_to_record(*args, **kwargs)
result = (
None
if record_type.result_cls is None
else record_type.result_cls(*r)
if tuple_result
else record_type.result_cls(r)
return functools.partial(
_record_function_inner, record_type, method, tuple_result, id_field_name
)


def get_arg_fields(spec: FullArgSpec):
arg_fields = []
defaults = len(spec.defaults) if spec.defaults else 0
for i, arg in enumerate(spec.args):
annotation = spec.annotations.get(arg)
if i >= len(spec.args) - defaults:
arg_fields.append(
(
arg,
annotation,
dataclasses.field(
default=spec.defaults[i - len(spec.args) + defaults]
if spec.defaults
else None
), # type: ignore
)
)
RECORDED_BY_HIGHER_FUNCTION.set(False)
recorder.add_record(record_type(params=params, result=result))
return r
else:
arg_fields.append((arg, annotation, None))

return arg_fields

return record_replay_wrapper

return record_function_inner
def _record_function_inner(record_type, method, tuple_result, id_field_name, func_to_record):
    """Wrap ``func_to_record`` so that its calls are recorded or replayed.

    Args:
        record_type: Either a ``Record`` subclass, or a string name. For a
            string, params/result dataclasses and a ``Record`` subclass are
            generated from the signature of ``func_to_record``.
        method: True when ``func_to_record`` is a method; its first positional
            argument (``self``) is then left out of the recorded params.
        tuple_result: When True, the return value is unpacked positionally
            into ``result_cls``; otherwise it is passed as a single argument.
        id_field_name: For methods, optional name of an attribute of ``self``
            whose value is prepended to the recorded params.
        func_to_record: The function being decorated.

    Returns:
        ``func_to_record`` itself when no record mode is configured in the
        environment, otherwise the recording/replaying wrapper.
    """
    # To avoid runtime overhead and other unpleasantness, we only apply the
    # record/replay decorator if a relevant env var is set.
    if get_record_mode_from_env() is None:
        return func_to_record

    if isinstance(record_type, str):
        return_type = signature(func_to_record).return_annotation
        arg_fields = get_arg_fields(getfullargspec(func_to_record))
        if method:
            # The wrapper strips 'self' before constructing params, so the
            # generated dataclass must not contain a field for it; otherwise
            # every recorded argument would land in the wrong field.
            arg_fields = arg_fields[1:]
        params_cls = dataclasses.make_dataclass(f"{record_type}Params", arg_fields)
        result_cls = dataclasses.make_dataclass(
            f"{record_type}Result", [("return_val", return_type)]
        )

        record_type = type(
            f"{record_type}Record", (Record,), {"params_cls": params_cls, "result_cls": result_cls}
        )

    @functools.wraps(func_to_record)
    def record_replay_wrapper(*args, **kwargs) -> Any:
        recorder: Optional[Recorder] = None
        try:
            from dbt_common.context import get_invocation_context

            recorder = get_invocation_context().recorder
        except LookupError:
            # No invocation context is available; fall through with no recorder.
            pass

        if recorder is None:
            return func_to_record(*args, **kwargs)

        # When a filter of recorded types is configured, skip record types
        # that match neither by class name nor by group.
        if recorder.recorded_types is not None and not (
            record_type.__name__ in recorder.recorded_types
            or record_type.group in recorder.recorded_types
        ):
            return func_to_record(*args, **kwargs)

        # For methods, peel off the 'self' argument before calling the
        # params constructor.
        param_args = args[1:] if method else args
        if method and id_field_name is not None:
            param_args = (getattr(args[0], id_field_name),) + param_args

        params = record_type.params_cls(*param_args, **kwargs)

        # Params classes may opt out of recording via an '_include' hook.
        include = True
        if hasattr(params, "_include"):
            include = params._include()

        if not include:
            return func_to_record(*args, **kwargs)

        if recorder.mode == RecorderMode.REPLAY:
            return recorder.expect_record(params)
        if RECORDED_BY_HIGHER_FUNCTION.get():
            # An enclosing recorded function is already running; do not
            # record this nested call separately.
            return func_to_record(*args, **kwargs)

        RECORDED_BY_HIGHER_FUNCTION.set(True)
        try:
            r = func_to_record(*args, **kwargs)
            if record_type.result_cls is None:
                result = None
            elif tuple_result:
                result = record_type.result_cls(*r)
            else:
                result = record_type.result_cls(r)
        finally:
            # Always clear the flag, even when the wrapped function raises;
            # otherwise later calls in this context would never be recorded.
            RECORDED_BY_HIGHER_FUNCTION.set(False)
        recorder.add_record(record_type(params=params, result=result))
        return r

    return record_replay_wrapper
19 changes: 18 additions & 1 deletion tests/unit/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional

from dbt_common.context import set_invocation_context, get_invocation_context
from dbt_common.record import record_function, Record, Recorder, RecorderMode
from dbt_common.record import record_function, Record, Recorder, RecorderMode, auto_record_function


@dataclasses.dataclass
Expand Down Expand Up @@ -199,3 +199,20 @@ def outer_func(a: int, b: str) -> str:

result = outer_func(123, "abc")
assert result == "123abc124abc"


def test_auto_decorator_records(setup) -> None:
    """Check that an auto-decorated function records its params and result."""
    os.environ["DBT_RECORDER_MODE"] = "Record"
    recorder = Recorder(RecorderMode.RECORD, None)
    set_invocation_context({})
    get_invocation_context().recorder = recorder

    @auto_record_function("TestAuto")
    def test_func(a: int, b: str, c: Optional[str] = None) -> str:
        return str(a) + b + (c if c else "")

    test_func(123, "abc")

    # The generated record class is named "<record_name>Record"; inspect the
    # most recently added record of that type.
    latest = recorder._records_by_type["TestAutoRecord"][-1]
    assert latest.params.a == 123
    assert latest.params.b == "abc"
    assert latest.result.return_val == "123abc"
Loading