Skip to content

Commit

Permalink
type: remove None-type for connection index for simplicity
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed Jun 26, 2024
1 parent 7039bab commit 0bd4705
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 21 deletions.
47 changes: 41 additions & 6 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,24 @@ def override_allowed_types(
def get_system_index(
self, system: "SystemType | StaticSystemType"
) -> SystemIdxType:
"""
Get the index of the system object in the system list.
System list is private, so this is the only way to get the index of the system object.
Example
-------
>>> system_collection: SystemCollectionProtocol
>>> system: SystemType
...
>>> system_idx = system_collection.get_system_index(system) # save idx
...
>>> system = system_collection[system_idx] # just need idx to retrieve
Parameters
----------
system: SystemType
System object to be found in the system list.
"""
n_systems = len(self) # Total number of systems from mixed-in class

sys_idx: SystemIdxType
Expand All @@ -170,13 +188,18 @@ def get_system_index(

@final
def systems(self) -> Generator[StaticSystemType, None, None]:
# assert self._finalize_flag, "The simulator is not finalized."
"""
Iterate over all systems in the system collection.
If the system collection is finalized, block objects are also included.
"""
for system in self.__systems:
yield system

@final
def block_systems(self) -> Generator[BlockSystemType, None, None]:
# assert self._finalize_flag, "The simulator is not finalized."
"""
Iterate over all block systems in the system collection.
"""
for block in self.__final_blocks:
yield block

Expand Down Expand Up @@ -208,24 +231,36 @@ def finalize(self) -> None:

@final
def synchronize(self, time: np.float64) -> None:
# Collection call _feature_group_synchronize
"""
Call synchronize functions for all features.
Features are registered in _feature_group_synchronize.
"""
for func in self._feature_group_synchronize:
func(time=time)

@final
def constrain_values(self, time: np.float64) -> None:
# Collection call _feature_group_constrain_values
"""
Call constrain values functions for all features.
Features are registered in _feature_group_constrain_values.
"""
for func in self._feature_group_constrain_values:
func(time=time)

@final
def constrain_rates(self, time: np.float64) -> None:
# Collection call _feature_group_constrain_rates
"""
Call constrain rates functions for all features.
Features are registered in _feature_group_constrain_rates.
"""
for func in self._feature_group_constrain_rates:
func(time=time)

@final
def apply_callbacks(self, time: np.float64, current_step: int) -> None:
# Collection call _feature_group_callback
"""
Call callback functions for all features.
Features are registered in _feature_group_callback.
"""
for func in self._feature_group_callback:
func(time=time, current_step=current_step)
23 changes: 12 additions & 11 deletions elastica/modules/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def connect(
self: SystemCollectionProtocol,
first_rod: "RodType | RigidBodyType",
second_rod: "RodType | RigidBodyType",
first_connect_idx: ConnectionIndex = None,
second_connect_idx: ConnectionIndex = None,
first_connect_idx: ConnectionIndex = (),
second_connect_idx: ConnectionIndex = (),
) -> ModuleProtocol:
"""
This method connects two rod-like objects using the selected joint class.
Expand All @@ -56,9 +56,9 @@ def connect(
Rod-like object
second_rod : RodType | RigidBodyType
Rod-like object
first_connect_idx : Optional[int]
first_connect_idx : ConnectionIndex
Index of first rod for joint.
second_connect_idx : Optional[int]
second_connect_idx : ConnectionIndex
Index of second rod for joint.
Returns
Expand Down Expand Up @@ -173,8 +173,8 @@ def __init__(
self._second_sys_idx: SystemIdxType = second_sys_idx
self._first_sys_n_lim: int = first_sys_nlim
self._second_sys_n_lim: int = second_sys_nlim
self.first_sys_connection_idx: ConnectionIndex = None
self.second_sys_connection_idx: ConnectionIndex = None
self.first_sys_connection_idx: ConnectionIndex = ()
self.second_sys_connection_idx: ConnectionIndex = ()
self._connect_cls: Type[FreeJoint]

def set_index(
Expand All @@ -188,7 +188,7 @@ def set_index(
), f"Type of first_connect_idx :{first_type} is different than second_connect_idx :{second_type}"

# Check if the type of idx variables are correct.
allow_types = (int, np.int_, list, tuple, np.ndarray, type(None))
allow_types = (int, np.int_, list, tuple, np.ndarray)
assert isinstance(
first_idx, allow_types
), f"Connection index type is not supported :{first_type}, please try one of the following :{allow_types}"
Expand Down Expand Up @@ -227,10 +227,7 @@ def set_index(
), "Connection index of second rod exceeds its dof : {}".format(
self._second_sys_n_lim
)
elif first_idx is None:
# Do nothing if idx are None
pass
else:
elif isinstance(first_idx, (int, np.int_)):
# The addition of +1 and and <= check on the RHS is because
# connections can be made to the node indices as well
first_idx__ = cast(int, first_idx)
Expand All @@ -245,6 +242,10 @@ def set_index(
), "Connection index of second rod exceeds its dof : {}".format(
self._second_sys_n_lim
)
else:
raise TypeError(
"Connection index type is not supported :{}".format(first_type)
)

self.first_sys_connection_idx = first_idx
self.second_sys_connection_idx = second_idx
Expand Down
5 changes: 3 additions & 2 deletions elastica/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@

# Indexing types
# TODO: Maybe just use slice??
ConstrainingIndex: TypeAlias = list[int] | tuple[int, ...] | np.typing.NDArray
ConstrainingIndex: TypeAlias = tuple[int, ...]
ConnectionIndex: TypeAlias = (
int | np.int_ | list[int] | tuple[int] | np.typing.NDArray | None
int | np.int_ | list[int] | tuple[int, ...] | np.typing.NDArray[np.int32]
)

# Operators in elastica.modules
# TODO: can be more specific.
OperatorParam = ParamSpec("OperatorParam")
OperatorCallbackType: TypeAlias = Callable[..., None]
OperatorFinalizeType: TypeAlias = Callable[..., None]
Expand Down
5 changes: 3 additions & 2 deletions tests/test_modules/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_set_index_with_illegal_type_second_idx_throws(

# Below test is to increase code coverage. If we pass nothing or idx=None, then do nothing.
def test_set_index_no_input(self, load_connect):
load_connect.set_index(first_idx=None, second_idx=None)
load_connect.set_index(first_idx=(), second_idx=())

@pytest.mark.parametrize(
"legal_idx", [(80, 80), (0, 50), (50, 0), (-20, -20), (-20, 50), (-50, -20)]
Expand Down Expand Up @@ -291,7 +291,8 @@ def test_connect_registers_and_returns_Connect(self, load_system_with_connects):
assert _mock_connect in system_collection_with_connections._connections
assert _mock_connect.__class__ == _Connect
# check sane defaults provided for connection indices
assert _mock_connect.id()[2] is None and _mock_connect.id()[3] is None
assert _mock_connect.id()[2] == ()
assert _mock_connect.id()[3] == ()

from elastica.joint import FreeJoint

Expand Down

0 comments on commit 0bd4705

Please sign in to comment.