From 0bd4705cd2e3373911657f7449550b7b6621fe17 Mon Sep 17 00:00:00 2001 From: Seung Hyun Kim Date: Wed, 26 Jun 2024 00:05:16 -0500 Subject: [PATCH] type: remove None-type for connection index for simplicity --- elastica/modules/base_system.py | 47 ++++++++++++++++++++++---- elastica/modules/connections.py | 23 +++++++------ elastica/typing.py | 5 +-- tests/test_modules/test_connections.py | 5 +-- 4 files changed, 59 insertions(+), 21 deletions(-) diff --git a/elastica/modules/base_system.py b/elastica/modules/base_system.py index 43666bf6..a7978ca9 100644 --- a/elastica/modules/base_system.py +++ b/elastica/modules/base_system.py @@ -144,6 +144,24 @@ def override_allowed_types( def get_system_index( self, system: "SystemType | StaticSystemType" ) -> SystemIdxType: + """ + Get the index of the system object in the system list. + System list is private, so this is the only way to get the index of the system object. + + Example + ------- + >>> system_collection: SystemCollectionProtocol + >>> system: SystemType + ... + >>> system_idx = system_collection.get_system_index(system) # save idx + ... + >>> system = system_collection[system_idx] # just need idx to retrieve + + Parameters + ---------- + system: SystemType + System object to be found in the system list. + """ n_systems = len(self) # Total number of systems from mixed-in class sys_idx: SystemIdxType @@ -170,13 +188,18 @@ def get_system_index( @final def systems(self) -> Generator[StaticSystemType, None, None]: - # assert self._finalize_flag, "The simulator is not finalized." + """ + Iterate over all systems in the system collection. + If the system collection is finalized, block objects are also included. + """ for system in self.__systems: yield system @final def block_systems(self) -> Generator[BlockSystemType, None, None]: - # assert self._finalize_flag, "The simulator is not finalized." + """ + Iterate over all block systems in the system collection. + """ for block in self.__final_blocks: yield block @@ -208,24 +231,36 @@ def finalize(self) -> None: @final def synchronize(self, time: np.float64) -> None: - # Collection call _feature_group_synchronize + """ + Call synchronize functions for all features. + Features are registered in _feature_group_synchronize. + """ for func in self._feature_group_synchronize: func(time=time) @final def constrain_values(self, time: np.float64) -> None: - # Collection call _feature_group_constrain_values + """ + Call constrain values functions for all features. + Features are registered in _feature_group_constrain_values. + """ for func in self._feature_group_constrain_values: func(time=time) @final def constrain_rates(self, time: np.float64) -> None: - # Collection call _feature_group_constrain_rates + """ + Call constrain rates functions for all features. + Features are registered in _feature_group_constrain_rates. + """ for func in self._feature_group_constrain_rates: func(time=time) @final def apply_callbacks(self, time: np.float64, current_step: int) -> None: - # Collection call _feature_group_callback + """ + Call callback functions for all features. + Features are registered in _feature_group_callback. + """ for func in self._feature_group_callback: func(time=time, current_step=current_step) diff --git a/elastica/modules/connections.py b/elastica/modules/connections.py index aa1f63c8..99f1c9ae 100644 --- a/elastica/modules/connections.py +++ b/elastica/modules/connections.py @@ -42,8 +42,8 @@ def connect( self: SystemCollectionProtocol, first_rod: "RodType | RigidBodyType", second_rod: "RodType | RigidBodyType", - first_connect_idx: ConnectionIndex = None, - second_connect_idx: ConnectionIndex = None, + first_connect_idx: ConnectionIndex = (), + second_connect_idx: ConnectionIndex = (), ) -> ModuleProtocol: """ This method connects two rod-like objects using the selected joint class. @@ -56,9 +56,9 @@ def connect( Rod-like object second_rod : RodType | RigidBodyType Rod-like object - first_connect_idx : Optional[int] + first_connect_idx : ConnectionIndex Index of first rod for joint. - second_connect_idx : Optional[int] + second_connect_idx : ConnectionIndex Index of second rod for joint. Returns @@ -173,8 +173,8 @@ def __init__( self._second_sys_idx: SystemIdxType = second_sys_idx self._first_sys_n_lim: int = first_sys_nlim self._second_sys_n_lim: int = second_sys_nlim - self.first_sys_connection_idx: ConnectionIndex = None - self.second_sys_connection_idx: ConnectionIndex = None + self.first_sys_connection_idx: ConnectionIndex = () + self.second_sys_connection_idx: ConnectionIndex = () self._connect_cls: Type[FreeJoint] def set_index( @@ -188,7 +188,7 @@ def set_index( ), f"Type of first_connect_idx :{first_type} is different than second_connect_idx :{second_type}" # Check if the type of idx variables are correct. - allow_types = (int, np.int_, list, tuple, np.ndarray, type(None)) + allow_types = (int, np.int_, list, tuple, np.ndarray) assert isinstance( first_idx, allow_types ), f"Connection index type is not supported :{first_type}, please try one of the following :{allow_types}" @@ -227,10 +227,7 @@ def set_index( ), "Connection index of second rod exceeds its dof : {}".format( self._second_sys_n_lim ) - elif first_idx is None: - # Do nothing if idx are None - pass - else: + elif isinstance(first_idx, (int, np.int_)): # The addition of +1 and and <= check on the RHS is because # connections can be made to the node indices as well first_idx__ = cast(int, first_idx) @@ -245,6 +242,10 @@ def set_index( ), "Connection index of second rod exceeds its dof : {}".format( self._second_sys_n_lim ) + else: + raise TypeError( + "Connection index type is not supported :{}".format(first_type) + ) self.first_sys_connection_idx = first_idx self.second_sys_connection_idx = second_idx diff --git a/elastica/typing.py b/elastica/typing.py index 1ba1a4ee..d05c4fea 100644 --- a/elastica/typing.py +++ b/elastica/typing.py @@ -57,12 +57,13 @@ # Indexing types # TODO: Maybe just use slice?? -ConstrainingIndex: TypeAlias = list[int] | tuple[int, ...] | np.typing.NDArray +ConstrainingIndex: TypeAlias = tuple[int, ...] ConnectionIndex: TypeAlias = ( - int | np.int_ | list[int] | tuple[int] | np.typing.NDArray | None + int | np.int_ | list[int] | tuple[int, ...] | np.typing.NDArray[np.int32] ) # Operators in elastica.modules +# TODO: can be more specific. OperatorParam = ParamSpec("OperatorParam") OperatorCallbackType: TypeAlias = Callable[..., None] OperatorFinalizeType: TypeAlias = Callable[..., None] diff --git a/tests/test_modules/test_connections.py b/tests/test_modules/test_connections.py index 26f0cb24..c63f8bfb 100644 --- a/tests/test_modules/test_connections.py +++ b/tests/test_modules/test_connections.py @@ -108,7 +108,7 @@ def test_set_index_with_illegal_type_second_idx_throws( # Below test is to increase code coverage. If we pass nothing or idx=None, then do nothing. def test_set_index_no_input(self, load_connect): - load_connect.set_index(first_idx=None, second_idx=None) + load_connect.set_index(first_idx=(), second_idx=()) @pytest.mark.parametrize( "legal_idx", [(80, 80), (0, 50), (50, 0), (-20, -20), (-20, 50), (-50, -20)] @@ -291,7 +291,8 @@ def test_connect_registers_and_returns_Connect(self, load_system_with_connects): assert _mock_connect in system_collection_with_connections._connections assert _mock_connect.__class__ == _Connect # check sane defaults provided for connection indices - assert _mock_connect.id()[2] is None and _mock_connect.id()[3] is None + assert _mock_connect.id()[2] == () + assert _mock_connect.id()[3] == () from elastica.joint import FreeJoint