Skip to content

Commit

Permalink
Merge pull request #427 from Point72/nk/fix_425_38
Browse files Browse the repository at this point in the history
Fix #425
  • Loading branch information
NeejWeej authored Jan 24, 2025
2 parents ca238be + e48c2a7 commit f376bd6
Show file tree
Hide file tree
Showing 3 changed files with 294 additions and 2 deletions.
10 changes: 9 additions & 1 deletion csp/impl/types/instantiation_type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,15 @@ def _add_scalar_value(self, arg, in_out_def):
def _is_scalar_value_matching_spec(self, inp_def_type, arg):
if inp_def_type is typing.Any:
return True
if UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False) is inp_def_type:
if CspTypingUtils.is_callable(inp_def_type):
return callable(arg)
resolved_type = UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False)
if resolved_type is inp_def_type:
return True
elif (
CspTypingUtils.is_generic_container(inp_def_type)
and CspTypingUtils.get_orig_base(inp_def_type) is resolved_type
):
return True
if CspTypingUtils.is_union_type(inp_def_type):
types = inp_def_type.__args__
Expand Down
19 changes: 18 additions & 1 deletion csp/impl/types/typing_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# utils for dealing with typing types
import collections
import numpy
import sys
import types
Expand Down Expand Up @@ -29,6 +30,15 @@ def get_origin(cls, typ):
raw_origin = typ.__origin__
return cls._ORIGIN_COMPAT_MAP.get(raw_origin, raw_origin)

@classmethod
def is_callable(cls, typ):
# Checks if a type annotation refers to a callable
if typ is typing.Callable:
return True
if not hasattr(typ, "__origin__"):
return False
return CspTypingUtils.get_origin(typ) is collections.abc.Callable

@classmethod
def is_numpy_array_type(cls, typ):
return CspTypingUtils.is_generic_container(typ) and CspTypingUtils.get_orig_base(typ) is numpy.ndarray
Expand All @@ -40,7 +50,10 @@ def is_numpy_nd_array_type(cls, typ):
# is typ a standard generic container
@classmethod
def is_generic_container(cls, typ):
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union
# isinstance(typing.Callable, typing._GenericAlias) passses in python 3.8, we don't want that
return (
isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union and typ is not typing.Callable
)

@classmethod
def is_union_type(cls, typ):
Expand Down Expand Up @@ -77,6 +90,10 @@ class CspTypingUtils39(CspTypingUtils37):
# To support PEP 585
_GENERIC_ALIASES = (typing._GenericAlias, typing.GenericAlias)

@classmethod
def is_generic_container(cls, typ):
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union

CspTypingUtils = CspTypingUtils39

if sys.version_info >= (3, 10):
Expand Down
267 changes: 267 additions & 0 deletions csp/tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import typing
import unittest
from datetime import datetime, time, timedelta
from typing import Callable, Dict, List, Optional, Union

import csp
import csp.impl.types.instantiation_type_resolver as type_resolver
from csp import ts
from csp.impl.types.typing_utils import CspTypingUtils
from csp.impl.wiring.runtime import build_graph

USE_PYDANTIC = os.environ.get("CSP_PYDANTIC")
Expand Down Expand Up @@ -621,6 +623,46 @@ def main():

csp.run(main, starttime=datetime.utcnow(), endtime=timedelta())

def test_typed_to_untyped_container_wrong(self):
@csp.graph
def g1(d: csp.ts[dict]):
pass

@csp.graph
def g2(d: csp.ts[set]):
pass

@csp.graph
def g3(d: csp.ts[list]):
pass

def main():
# This should fail - wrong key type in Dict
if USE_PYDANTIC:
msg = "(?s)1 validation error for csp.const.*Input should be a valid integer \\[type=int_type"
else:
msg = "In function csp\\.const: Expected ~T for argument 'value', got .* \\(dict\\)\\(T=typing\\.Dict\\[int, int\\]\\)"
with self.assertRaisesRegex(TypeError, msg):
g1(d=csp.const.using(T=typing.Dict[int, int])({"a": 10}))

# This should fail - wrong element type in Set
if USE_PYDANTIC:
msg = "(?s)1 validation error for csp.const.*Input should be a valid integer \\[type=int_type"
else:
msg = "In function csp\\.const: Expected ~T for argument 'value', got .* \\(set\\)\\(T=typing\\.Set\\[int\\]\\)"
with self.assertRaisesRegex(TypeError, msg):
g2(d=csp.const.using(T=typing.Set[int])(set(["z"])))

# This should fail - wrong element type in List
if USE_PYDANTIC:
msg = "(?s)1 validation error for csp.const.*Input should be a valid integer \\[type=int_type"
else:
msg = "In function csp\\.const: Expected ~T for argument 'value', got .* \\(list\\)\\(T=typing\\.List\\[int\\]\\)"
with self.assertRaisesRegex(TypeError, msg):
g3(d=csp.const.using(T=typing.List[int])(["d"]))

csp.run(main, starttime=datetime.utcnow(), endtime=timedelta())

def test_time_tzinfo(self):
import pytz

Expand Down Expand Up @@ -670,6 +712,231 @@ def g():
self.assertEqual(res["y"][0][1], set())
self.assertEqual(res["z"][0][1], {})

def test_callable_type_checking(self):
@csp.node
def node_callable_typed(x: ts[int], my_data: Callable[[int], int]) -> ts[int]:
if csp.ticked(x):
if my_data:
return my_data(x) if callable(my_data) else 12

@csp.node
def node_callable_untyped(x: ts[int], my_data: Callable) -> ts[int]:
if csp.ticked(x):
if my_data:
return my_data(x) if callable(my_data) else 12

def graph():
# These should work
node_callable_untyped(csp.const(10), lambda x: 2 * x)
node_callable_typed(csp.const(10), lambda x: x + 1)

# We intentionally allow setting None to be allowed
node_callable_typed(csp.const(10), None)
node_callable_untyped(csp.const(10), None)

# Here the Callable's type hints don't match the signature
# but we allow anyways, both with the pydantic version and without
node_callable_typed(csp.const(10), lambda x, y: "a")
node_callable_untyped(csp.const(10), lambda x, y: "a")

# This should fail - passing non-callable
if USE_PYDANTIC:
msg = "(?s)1 validation error for node_callable_untyped.*my_data.*Input should be callable \\[type=callable_type"
else:
msg = "In function node_callable_untyped: Expected typing\\.Callable for argument 'my_data', got 11 \\(int\\)"
with self.assertRaisesRegex(TypeError, msg):
node_callable_untyped(csp.const(10), 11)

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_optional_type_checking(self):
for use_dict in [True, False]:
if use_dict:

@csp.node
def node_optional_list_typed(x: ts[int], my_data: Optional[Dict[int, int]] = None) -> ts[int]:
if csp.ticked(x):
return my_data[0] if my_data else x

@csp.node
def node_optional_list_untyped(x: ts[int], my_data: Optional[dict] = None) -> ts[int]:
if csp.ticked(x):
return my_data[0] if my_data else x
else:

@csp.node
def node_optional_list_typed(x: ts[int], my_data: Optional[List[int]] = None) -> ts[int]:
if csp.ticked(x):
return my_data[0] if my_data else x

@csp.node
def node_optional_list_untyped(x: ts[int], my_data: Optional[list] = None) -> ts[int]:
if csp.ticked(x):
return my_data[0] if my_data else x

def graph():
# Optional[list] tests - these should work
node_optional_list_untyped(csp.const(10), {} if use_dict else [])
node_optional_list_untyped(csp.const(10), None)
node_optional_list_untyped(csp.const(10), {9: 10} if use_dict else [9])

# Optional[List[int]] tests
node_optional_list_typed(csp.const(10), None)
node_optional_list_typed(csp.const(10), {} if use_dict else [])
node_optional_list_typed(csp.const(10), {9: 10} if use_dict else [9])

# Here the List/Dict type hints don't match the signature
# But, for backwards compatibility (as this was the behavior with Optional in version 0.0.5)
# The pydantic version of the checks, however, catches this.
if USE_PYDANTIC:
msg = "(?s).*validation error.* for node_optional_list_typed.*my_data.*Input should be a valid integer.*type=int_parsing"
with self.assertRaisesRegex(TypeError, msg):
node_optional_list_typed(csp.const(10), {"a": "b"} if use_dict else ["a"])
else:
node_optional_list_typed(csp.const(10), {"a": "b"} if use_dict else ["a"])

# This should fail - type mismatch
if USE_PYDANTIC:
msg = "(?s)1 validation error for node_optional_list_typed.*my_data"
else:
msg = "In function node_optional_list_typed: Expected typing\\.(?:Optional\\[typing|Union\\[typing)\\..*"
with self.assertRaisesRegex(TypeError, msg):
node_optional_list_typed(csp.const(10), [] if use_dict else {})

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_optional_callable_type_checking(self):
@csp.node
def node_optional_callable_typed(x: ts[int], my_data: Optional[Callable[[int], int]] = None) -> ts[int]:
if csp.ticked(x):
return my_data(x) if my_data else x

@csp.node
def node_optional_callable_untyped(x: ts[int], my_data: Optional[Callable] = None) -> ts[int]:
if csp.ticked(x):
return my_data(x) if my_data else x

def graph():
# These should work for both typed and untyped
node_optional_callable_typed(csp.const(10), None)
node_optional_callable_untyped(csp.const(10), None)

# These should also work - valid callables
node_optional_callable_typed(csp.const(10), lambda x: x + 1)
node_optional_callable_untyped(csp.const(10), lambda x: 2 * x)

# Here the Callable's type hints don't match the signature
# but we allow anyways, both with the pydantic version and without
node_optional_callable_typed(csp.const(10), lambda x, y: "a")
node_optional_callable_untyped(csp.const(10), lambda x, y: "a")

# This should fail - passing non-callable to typed version
if USE_PYDANTIC:
msg = "(?s)1 validation error for node_optional_callable_typed.*my_data.*Input should be callable \\[type=callable_type"
else:
msg = "In function node_optional_callable_typed: Expected typing\\.(?:Optional\\[typing\\.Callable\\[\\[int\\], int\\]\\]|Union\\[typing\\.Callable\\[\\[int\\], int\\], NoneType\\]) for argument 'my_data', got 12 \\(int\\)"
with self.assertRaisesRegex(TypeError, msg):
node_optional_callable_typed(csp.const(10), 12)

# This should fail - passing non-callable to typed version
if USE_PYDANTIC:
msg = "(?s)1 validation error for node_optional_callable_typed.*my_data.*Input should be callable \\[type=callable_type"
else:
msg = "In function node_optional_callable_typed: Expected typing\\.(?:Optional\\[typing\\.Callable\\[\\[int\\], int\\]\\]|Union\\[typing\\.Callable\\[\\[int\\], int\\], NoneType\\]) for argument 'my_data', got 12 \\(int\\)"
with self.assertRaisesRegex(TypeError, msg):
node_optional_callable_typed(csp.const(10), 12)

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_union_type_checking(self):
@csp.node
def node_union_typed(x: ts[int], my_data: Union[int, str]) -> ts[int]:
if csp.ticked(x):
return x + int(my_data) if isinstance(my_data, str) else x + my_data

def graph():
# These should work - valid int inputs
node_union_typed(csp.const(10), 5)

# These should also work - valid str inputs
node_union_typed(csp.const(10), "123")

# These should fail - passing float when expecting Union[int, str]
if USE_PYDANTIC:
msg = "(?s)2 validation errors for node_union_typed.*my_data\\.int.*Input should be a valid integer, got a number with a fractional part.*my_data\\.str.*Input should be a valid string"
else:
msg = "In function node_union_typed: Expected typing\\.Union\\[int, str\\] for argument 'my_data', got 12\\.5 \\(float\\)"
with self.assertRaisesRegex(TypeError, msg):
node_union_typed(csp.const(10), 12.5)

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_union_list_type_checking(self):
@csp.node
def node_union_typed(x: ts[int], my_data: Union[List[str], int] = None) -> ts[int]:
if csp.ticked(x):
if isinstance(my_data, list):
return x + len(my_data)
return x + my_data

@csp.node
def node_union_untyped(x: ts[int], my_data: Union[list, int] = None) -> ts[int]:
if csp.ticked(x):
if isinstance(my_data, list):
return x + len(my_data)
return x + my_data

def graph():
# These should work - valid int inputs
node_union_typed(csp.const(10), 5)
node_union_untyped(csp.const(10), 42)

# These should work - valid list inputs
node_union_typed(csp.const(10), ["hello", "world"])
node_union_untyped(csp.const(10), ["hello", "world"])

# This should fail - passing float when expecting Union[List[str], int]
if USE_PYDANTIC:
msg = "(?s)2 validation errors for node_union_typed.*my_data\\.list.*Input should be a valid list.*my_data\\.int.*Input should be a valid integer, got a number with a fractional part"
else:
msg = "In function node_union_typed: Expected typing\\.Union\\[typing\\.List\\[str\\], int\\] for argument 'my_data', got 12\\.5 \\(float\\)"
with self.assertRaisesRegex(TypeError, msg):
node_union_typed(csp.const(10), 12.5)

# This should fail - passing list with wrong element type
if USE_PYDANTIC:
msg = "(?s)3 validation errors for node_union_typed.*my_data\\.list\\[str\\]\\.0.*Input should be a valid string.*my_data\\.list\\[str\\]\\.1.*Input should be a valid string.*my_data\\.int.*Input should be a valid integer"
with self.assertRaisesRegex(TypeError, msg):
node_union_typed(csp.const(10), [1, 2]) # List of ints instead of strings
else:
# We choose to intentionally not enforce the types provided
# to maintain previous flexibility when not using pydantic type validation
node_union_typed(csp.const(10), [1, 2])

node_union_untyped(csp.const(10), [1, 2])

csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

def test_is_callable(self):
"""Test CspTypingUtils.is_callable with various input types"""
# Test cases as (input, expected_result) pairs
test_cases = [
# Direct Callable types
(Callable, True),
(Callable[[int, str], bool], True),
(Callable[..., None], True),
(Callable[[int], str], True),
# optional Callable is not Callable
(Optional[Callable], False),
# Typing module types
(List[int], False),
(Dict[str, int], False),
(typing.Set[str], False),
]
for input_type, expected in test_cases:
result = CspTypingUtils.is_callable(input_type)
self.assertEqual(result, expected)


if __name__ == "__main__":
unittest.main()

0 comments on commit f376bd6

Please sign in to comment.