Skip to content

Commit

Permalink
fix numpy 2.2 change in annotated behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
skrawcz authored and elijahbenizzy committed Dec 12, 2024
1 parent 1c176f1 commit fc239a9
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def fn() -> int:
assert node_copy_copy.name == "rename_fn_again"


np_version = np.__version__
major, minor, _ = map(int, np_version.split("."))


@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 or higher")
def test_node_handles_annotated():
from typing import Annotated
Expand All @@ -56,15 +60,24 @@ def annotated_func(first: ArrayN[np.float64], other: float = 2.0) -> ArrayN[np.f

node = Node.from_fn(annotated_func)
assert node.name == "annotated_func"
expected = {
"first": (
Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]],
DependencyType.REQUIRED,
),
"other": (float, DependencyType.OPTIONAL),
}
if major == 2 and minor > 1: # greater that 2.1
expected = {
"first": (
Annotated[np.ndarray[tuple[int, ...], np.dtype[np.float64]], Literal["N"]],
DependencyType.REQUIRED,
),
"other": (float, DependencyType.OPTIONAL),
}
else:
expected = {
"first": (
Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]],
DependencyType.REQUIRED,
),
"other": (float, DependencyType.OPTIONAL),
}
assert node.input_types == expected
assert node.type == Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]]
assert node.type == Annotated[np.ndarray[tuple[int, ...], np.dtype[np.float64]], Literal["N"]]


@pytest.mark.parametrize(
Expand Down

0 comments on commit fc239a9

Please sign in to comment.