diff --git a/elastica/rod/cosserat_rod.py b/elastica/rod/cosserat_rod.py index 051b3751..7d8a4f4a 100644 --- a/elastica/rod/cosserat_rod.py +++ b/elastica/rod/cosserat_rod.py @@ -2,6 +2,7 @@ import numpy as np +from numpy.typing import NDArray import functools import numba from elastica.rod import RodBase @@ -20,18 +21,21 @@ _difference, _average, ) -from typing import Optional +from typing import Any, Optional +from typing_extensions import Self + +from elastica.typing import RodType position_difference_kernel = _difference position_average = _average @functools.lru_cache(maxsize=1) -def _get_z_vector(): +def _get_z_vector() -> NDArray[np.floating]: return np.array([0.0, 0.0, 1.0]).reshape(3, -1) -def _compute_sigma_kappa_for_blockstructure(memory_block): +def _compute_sigma_kappa_for_blockstructure(memory_block: RodType) -> None: """ This function is a wrapper to call functions which computes shear stretch, strain and bending twist and strain. @@ -147,39 +151,39 @@ class CosseratRod(RodBase, KnotTheory): def __init__( self, - n_elements, - position, - velocity, - omega, - acceleration, - angular_acceleration, - directors, - radius, - mass_second_moment_of_inertia, - inv_mass_second_moment_of_inertia, - shear_matrix, - bend_matrix, - density, - volume, - mass, - internal_forces, - internal_torques, - external_forces, - external_torques, - lengths, - rest_lengths, - tangents, - dilatation, - dilatation_rate, - voronoi_dilatation, - rest_voronoi_lengths, - sigma, - kappa, - rest_sigma, - rest_kappa, - internal_stress, - internal_couple, - ring_rod_flag, + n_elements: int, + position: NDArray[np.floating], + velocity: NDArray[np.floating], + omega: NDArray[np.floating], + acceleration: NDArray[np.floating], + angular_acceleration: NDArray[np.floating], + directors: NDArray[np.floating], + radius: NDArray[np.floating], + mass_second_moment_of_inertia: NDArray[np.floating], + inv_mass_second_moment_of_inertia: NDArray[np.floating], + shear_matrix: NDArray[np.floating], + bend_matrix: NDArray[np.floating], + density: NDArray[np.floating], + volume: NDArray[np.floating], + mass: NDArray[np.floating], + internal_forces: NDArray[np.floating], + internal_torques: NDArray[np.floating], + external_forces: NDArray[np.floating], + external_torques: NDArray[np.floating], + lengths: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + dilatation: NDArray[np.floating], + dilatation_rate: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + sigma: NDArray[np.floating], + kappa: NDArray[np.floating], + rest_sigma: NDArray[np.floating], + rest_kappa: NDArray[np.floating], + internal_stress: NDArray[np.floating], + internal_couple: NDArray[np.floating], + ring_rod_flag: bool, ): self.n_elems = n_elements self.position_collection = position @@ -242,17 +246,17 @@ def __init__( def straight_rod( cls, n_elements: int, - start: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + start: NDArray[np.floating], + direction: NDArray[np.floating], + normal: NDArray[np.floating], base_length: float, base_radius: float, density: float, *, nu: Optional[float] = None, youngs_modulus: float, - **kwargs, - ): + **kwargs: Any, + ) -> Self: """ Cosserat rod constructor for straight-rod geometry. @@ -390,17 +394,17 @@ def straight_rod( def ring_rod( cls, n_elements: int, - ring_center_position: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + ring_center_position: NDArray[np.floating], + direction: NDArray[np.floating], + normal: NDArray[np.floating], base_length: float, base_radius: float, density: float, *, nu: Optional[float] = None, youngs_modulus: float, - **kwargs, - ): + **kwargs: Any, + ) -> Self: """ Cosserat rod constructor for straight-rod geometry. @@ -536,7 +540,7 @@ def ring_rod( rod.REQUISITE_MODULE.append(Constraints) return rod - def compute_internal_forces_and_torques(self, time): + def compute_internal_forces_and_torques(self, time: float) -> None: """ Compute internal forces and torques. We need to compute internal forces and torques before the acceleration because they are used in interaction. Thus in order to speed up simulation, we will compute internal forces and torques @@ -591,7 +595,7 @@ def compute_internal_forces_and_torques(self, time): ) # Interface to time-stepper mixins (Symplectic, Explicit), which calls this method - def update_accelerations(self, time): + def update_accelerations(self, time: float) -> None: """ Updates the acceleration variables @@ -613,12 +617,12 @@ def update_accelerations(self, time): self.dilatation, ) - def zeroed_out_external_forces_and_torques(self, time): + def zeroed_out_external_forces_and_torques(self, time: np.floating) -> None: _zeroed_out_external_forces_and_torques( self.external_forces, self.external_torques ) - def compute_translational_energy(self): + def compute_translational_energy(self) -> NDArray[np.floating]: """ Compute total translational energy of the rod at the instance. """ @@ -632,7 +636,7 @@ def compute_translational_energy(self): ).sum() ) - def compute_rotational_energy(self): + def compute_rotational_energy(self) -> NDArray[np.floating]: """ Compute total rotational energy of the rod at the instance. """ @@ -642,7 +646,7 @@ def compute_rotational_energy(self): ) return 0.5 * np.einsum("ik,ik->k", self.omega_collection, J_omega_upon_e).sum() - def compute_velocity_center_of_mass(self): + def compute_velocity_center_of_mass(self) -> NDArray[np.floating]: """ Compute velocity center of mass of the rod at the instance. """ @@ -651,7 +655,7 @@ def compute_velocity_center_of_mass(self): return sum_mass_times_velocity / self.mass.sum() - def compute_position_center_of_mass(self): + def compute_position_center_of_mass(self) -> NDArray[np.floating]: """ Compute position center of mass of the rod at the instance. """ @@ -660,7 +664,7 @@ def compute_position_center_of_mass(self): return sum_mass_times_position / self.mass.sum() - def compute_bending_energy(self): + def compute_bending_energy(self) -> NDArray[np.floating]: """ Compute total bending energy of the rod at the instance. """ @@ -676,7 +680,7 @@ def compute_bending_energy(self): ).sum() ) - def compute_shear_energy(self): + def compute_shear_energy(self) -> NDArray[np.floating]: """ Compute total shear energy of the rod at the instance. """ @@ -695,8 +699,12 @@ def compute_shear_energy(self): @numba.njit(cache=True) def _compute_geometry_from_state( - position_collection, volume, lengths, tangents, radius -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], +) -> None: """ Update given . """ @@ -719,16 +727,16 @@ def _compute_geometry_from_state( @numba.njit(cache=True) def _compute_all_dilatations( - position_collection, - volume, - lengths, - tangents, - radius, - dilatation, - rest_lengths, - rest_voronoi_lengths, - voronoi_dilatation, -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], + dilatation: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], +) -> None: """ Update """ @@ -749,8 +757,12 @@ def _compute_all_dilatations( @numba.njit(cache=True) def _compute_dilatation_rate( - position_collection, velocity_collection, lengths, rest_lengths, dilatation_rate -): + position_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + lengths: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + dilatation_rate: NDArray[np.floating], +) -> None: """ Update dilatation_rate given position, velocity, length, and rest_length """ @@ -776,18 +788,18 @@ def _compute_dilatation_rate( @numba.njit(cache=True) def _compute_shear_stretch_strains( - position_collection, - volume, - lengths, - tangents, - radius, - rest_lengths, - rest_voronoi_lengths, - dilatation, - voronoi_dilatation, - director_collection, - sigma, -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + dilatation: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + director_collection: NDArray[np.floating], + sigma: NDArray[np.floating], +) -> None: """ Update given . """ @@ -811,21 +823,21 @@ def _compute_shear_stretch_strains( @numba.njit(cache=True) def _compute_internal_shear_stretch_stresses_from_model( - position_collection, - volume, - lengths, - tangents, - radius, - rest_lengths, - rest_voronoi_lengths, - dilatation, - voronoi_dilatation, - director_collection, - sigma, - rest_sigma, - shear_matrix, - internal_stress, -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + dilatation: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + director_collection: NDArray[np.floating], + sigma: NDArray[np.floating], + rest_sigma: NDArray[np.floating], + shear_matrix: NDArray[np.floating], + internal_stress: NDArray[np.floating], +) -> None: """ Update given . @@ -850,7 +862,11 @@ def _compute_internal_shear_stretch_stresses_from_model( @numba.njit(cache=True) -def _compute_bending_twist_strains(director_collection, rest_voronoi_lengths, kappa): +def _compute_bending_twist_strains( + director_collection: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + kappa: NDArray[np.floating], +) -> None: """ Update given . """ @@ -864,13 +880,13 @@ def _compute_bending_twist_strains(director_collection, rest_voronoi_lengths, ka @numba.njit(cache=True) def _compute_internal_bending_twist_stresses_from_model( - director_collection, - rest_voronoi_lengths, - internal_couple, - bend_matrix, - kappa, - rest_kappa, -): + director_collection: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + internal_couple: NDArray[np.floating], + bend_matrix: NDArray[np.floating], + kappa: NDArray[np.floating], + rest_kappa: NDArray[np.floating], +) -> None: """ Upate given . @@ -893,23 +909,23 @@ def _compute_internal_bending_twist_stresses_from_model( @numba.njit(cache=True) def _compute_internal_forces( - position_collection, - volume, - lengths, - tangents, - radius, - rest_lengths, - rest_voronoi_lengths, - dilatation, - voronoi_dilatation, - director_collection, - sigma, - rest_sigma, - shear_matrix, - internal_stress, - internal_forces, - ghost_elems_idx, -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + dilatation: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + director_collection: NDArray[np.floating], + sigma: NDArray[np.floating], + rest_sigma: NDArray[np.floating], + shear_matrix: NDArray[np.floating], + internal_stress: NDArray[np.floating], + internal_forces: NDArray[np.floating], + ghost_elems_idx: NDArray[np.floating], +) -> None: """ Update given . """ @@ -954,26 +970,26 @@ def _compute_internal_forces( @numba.njit(cache=True) def _compute_internal_torques( - position_collection, - velocity_collection, - tangents, - lengths, - rest_lengths, - director_collection, - rest_voronoi_lengths, - bend_matrix, - rest_kappa, - kappa, - voronoi_dilatation, - mass_second_moment_of_inertia, - omega_collection, - internal_stress, - internal_couple, - dilatation, - dilatation_rate, - internal_torques, - ghost_voronoi_idx, -): + position_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + tangents: NDArray[np.floating], + lengths: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + director_collection: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + bend_matrix: NDArray[np.floating], + rest_kappa: NDArray[np.floating], + kappa: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + mass_second_moment_of_inertia: NDArray[np.floating], + omega_collection: NDArray[np.floating], + internal_stress: NDArray[np.floating], + internal_couple: NDArray[np.floating], + dilatation: NDArray[np.floating], + dilatation_rate: NDArray[np.floating], + internal_torques: NDArray[np.floating], + ghost_voronoi_idx: NDArray[np.integer], +) -> None: """ Update . """ @@ -1043,16 +1059,16 @@ def _compute_internal_torques( @numba.njit(cache=True) def _update_accelerations( - acceleration_collection, - internal_forces, - external_forces, - mass, - alpha_collection, - inv_mass_second_moment_of_inertia, - internal_torques, - external_torques, - dilatation, -): + acceleration_collection: NDArray[np.floating], + internal_forces: NDArray[np.floating], + external_forces: NDArray[np.floating], + mass: NDArray[np.floating], + alpha_collection: NDArray[np.floating], + inv_mass_second_moment_of_inertia: NDArray[np.floating], + internal_torques: NDArray[np.floating], + external_torques: NDArray[np.floating], + dilatation: NDArray[np.floating], +) -> None: """ Update given . """ @@ -1077,7 +1093,9 @@ def _update_accelerations( @numba.njit(cache=True) -def _zeroed_out_external_forces_and_torques(external_forces, external_torques): +def _zeroed_out_external_forces_and_torques( + external_forces: NDArray[np.floating], external_torques: NDArray[np.floating] +) -> None: """ This function is to zeroed out external forces and torques. diff --git a/elastica/rod/data_structures.py b/elastica/rod/data_structures.py index 3c8b4032..c65a9c4b 100644 --- a/elastica/rod/data_structures.py +++ b/elastica/rod/data_structures.py @@ -1,6 +1,9 @@ __doc__ = "Data structure wrapper for rod components" +from typing import Any, Optional +from typing_extensions import Self import numpy as np +from numpy.typing import NDArray from numba import njit from elastica._rotations import _get_rotation_matrix, _rotate from elastica._linalg import _batch_matmul @@ -9,7 +12,7 @@ # FIXME : Explicit Stepper doesn't work as States lose the # views they initially had when working with a timestepper. # class _RodExplicitStepperMixin: -# def __init__(self): +# def __init__(self) -> None: # ( # self.state, # self.__deriv_state, @@ -43,7 +46,7 @@ class _RodSymplecticStepperMixin: - def __init__(self): + def __init__(self) -> None: self.kinematic_states = _KinematicState( self.position_collection, self.director_collection ) @@ -60,18 +63,40 @@ def __init__(self): # is another function self.kinematic_rates = self.dynamic_states.kinematic_rates - def update_internal_forces_and_torques(self, time, *args, **kwargs): + def update_internal_forces_and_torques( + self, time: np.floating, *args: Any, **kwargs: Any + ) -> None: self.compute_internal_forces_and_torques(time) - def dynamic_rates(self, time, prefac, *args, **kwargs): + def dynamic_rates( + self, time: np.floating, prefac: np.floating, *args: Any, **kwargs: Any + ) -> NDArray[np.floating]: self.update_accelerations(time) return self.dynamic_states.dynamic_rates(time, prefac, *args, **kwargs) - def reset_external_forces_and_torques(self, time, *args, **kwargs): + def reset_external_forces_and_torques( + self, time: np.floating, *args: Any, **kwargs: Any + ) -> None: self.zeroed_out_external_forces_and_torques(time) -def _bootstrap_from_data(stepper_type: str, n_elems: int, vector_states, matrix_states): +def _bootstrap_from_data( + stepper_type: str, + n_elems: int, + vector_states: NDArray[np.floating], + matrix_states: NDArray[np.floating], +) -> Optional[ + tuple[ + "_State", + "_DerivativeState", + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + ] +]: """Returns states wrapping numpy arrays based on the time-stepping algorithm Convenience method that takes in rod internal (raw np.ndarray) data, create views @@ -119,7 +144,7 @@ def _bootstrap_from_data(stepper_type: str, n_elems: int, vector_states, matrix_ # ) raise NotImplementedError else: - return + return None n_velocity_end = n_nodes + n_nodes velocity = np.ndarray.view(vector_states[..., n_nodes:n_velocity_end]) @@ -154,10 +179,10 @@ class _State: def __init__( self, n_elems: int, - position_collection_view, - director_collection_view, - kinematic_rate_collection_view, - ): + position_collection_view: NDArray[np.floating], + director_collection_view: NDArray[np.floating], + kinematic_rate_collection_view: NDArray[np.floating], + ) -> None: """ Parameters ---------- @@ -173,7 +198,7 @@ def __init__( self.director_collection = director_collection_view self.kinematic_rate_collection = kinematic_rate_collection_view - def __iadd__(self, scaled_deriv_array): + def __iadd__(self, scaled_deriv_array: NDArray[np.floating]) -> Self: """overloaded += operator The add for directors is customized to reflect Rodrigues' rotation @@ -242,7 +267,7 @@ def __iadd__(self, scaled_deriv_array): ] return self - def __add__(self, scaled_derivative_state): + def __add__(self, scaled_derivative_state: NDArray[np.floating]) -> "_State": """overloaded + operator, useful in state.k1 = state + dt * deriv_state The add for directors is customized to reflect Rodrigues' rotation @@ -296,7 +321,9 @@ class _DerivativeState: /multiplication used. """ - def __init__(self, _unused_n_elems: int, rate_collection_view): + def __init__( + self, _unused_n_elems: int, rate_collection_view: NDArray[np.floating] + ) -> None: """ Parameters ---------- @@ -307,7 +334,7 @@ def __init__(self, _unused_n_elems: int, rate_collection_view): super(_DerivativeState, self).__init__() self.rate_collection = rate_collection_view - def __rmul__(self, scalar): + def __rmul__(self, scalar: np.floating) -> NDArray[np.floating]: """overloaded scalar * self, Parameters @@ -355,7 +382,7 @@ def __rmul__(self, scalar): """ return scalar * self.rate_collection - def __mul__(self, scalar): + def __mul__(self, scalar: np.floating) -> NDArray[np.floating]: """overloaded self * scalar TODO Check if this pattern (forwarding to __mul__) has @@ -388,7 +415,11 @@ class _KinematicState: only these methods are provided. """ - def __init__(self, position_collection_view, director_collection_view): + def __init__( + self, + position_collection_view: NDArray[np.floating], + director_collection_view: NDArray[np.floating], + ) -> None: """ Parameters ---------- @@ -403,13 +434,13 @@ def __init__(self, position_collection_view, director_collection_view): @njit(cache=True) def overload_operator_kinematic_numba( - n_nodes, - prefac, - position_collection, - director_collection, - velocity_collection, - omega_collection, -): + n_nodes: int, + prefac: np.floating, + position_collection: NDArray[np.floating], + director_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + omega_collection: NDArray[np.floating], +) -> None: """overloaded += operator The add for directors is customized to reflect Rodrigues' rotation @@ -449,11 +480,11 @@ class _DynamicState: def __init__( self, - v_w_collection, - dvdt_dwdt_collection, - velocity_collection, - omega_collection, - ): + v_w_collection: NDArray[np.floating], + dvdt_dwdt_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + omega_collection: NDArray[np.floating], + ) -> None: """ Parameters ---------- @@ -470,7 +501,9 @@ def __init__( self.velocity_collection = velocity_collection self.omega_collection = omega_collection - def kinematic_rates(self, time, prefac): + def kinematic_rates( + self, time: np.floating, prefac: np.floating + ) -> tuple[NDArray[np.floating], NDArray[np.floating]]: """Yields kinematic rates to interact with _KinematicState Returns @@ -486,7 +519,9 @@ def kinematic_rates(self, time, prefac): # Comes from kin_state -> (x,Q) += dt * (v,w) <- First part of dyn_state return self.velocity_collection, self.omega_collection - def dynamic_rates(self, time, prefac): + def dynamic_rates( + self, time: np.floating, prefac: np.floating + ) -> NDArray[np.floating]: """Yields dynamic rates to add to with _DynamicState Returns ------- @@ -501,7 +536,10 @@ def dynamic_rates(self, time, prefac): @njit(cache=True) -def overload_operator_dynamic_numba(rate_collection, scaled_second_deriv_array): +def overload_operator_dynamic_numba( + rate_collection: NDArray[np.floating], + scaled_second_deriv_array: NDArray[np.floating], +) -> None: """overloaded += operator, updating dynamic_rates Parameters ---------- diff --git a/elastica/rod/factory_function.py b/elastica/rod/factory_function.py index 6eca6b5a..22cc33bc 100644 --- a/elastica/rod/factory_function.py +++ b/elastica/rod/factory_function.py @@ -1,30 +1,64 @@ __doc__ = """ Factory function to allocate variables for Cosserat Rod""" -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import logging import numpy as np from numpy.testing import assert_allclose +from numpy.typing import NDArray from elastica.utils import MaxDimension, Tolerance from elastica._linalg import _batch_cross, _batch_norm, _batch_dot def allocate( - n_elements, - direction, - normal, - base_length, - base_radius, - density, - youngs_modulus: float, + n_elements: int, + direction: NDArray[np.floating], + normal: NDArray[np.floating], + base_length: np.floating, + base_radius: np.floating, + density: np.floating, + youngs_modulus: np.floating, *, rod_origin_position: np.ndarray, ring_rod_flag: bool, - shear_modulus: Optional[float] = None, + shear_modulus: Optional[np.floating] = None, position: Optional[np.ndarray] = None, directors: Optional[np.ndarray] = None, rest_sigma: Optional[np.ndarray] = None, rest_kappa: Optional[np.ndarray] = None, - **kwargs, -): + **kwargs: Any, +) -> tuple[ + int, + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], +]: log = logging.getLogger() if "poisson_ratio" in kwargs: @@ -335,14 +369,16 @@ def allocate( """ -def _assert_dim(vector, max_dim: int, name: str): +def _assert_dim(vector: np.ndarray, max_dim: int, name: str) -> None: assert vector.ndim < max_dim, ( f"Input {name} dimension is not correct {vector.shape}" + f" It should be maximum {max_dim}D vector or single floating number." ) -def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): +def _assert_shape( + array: np.ndarray, expected_shape: Tuple[int, ...], name: str +) -> None: assert array.shape == expected_shape, ( f"Given {name} shape is not correct, it should be " + str(expected_shape) @@ -351,7 +387,9 @@ def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): ) -def _position_validity_checker(position, start, n_elements): +def _position_validity_checker( + position: NDArray[np.floating], start: NDArray[np.floating], n_elements: int +) -> None: """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements + 1), "position") @@ -367,7 +405,9 @@ def _position_validity_checker(position, start, n_elements): ) -def _directors_validity_checker(directors, tangents, n_elements): +def _directors_validity_checker( + directors: NDArray[np.floating], tangents: NDArray[np.floating], n_elements: int +) -> None: """Checker on user-defined directors validity""" _assert_shape( directors, (MaxDimension.value(), MaxDimension.value(), n_elements), "directors" @@ -413,7 +453,11 @@ def _directors_validity_checker(directors, tangents, n_elements): ) -def _position_validity_checker_ring_rod(position, ring_center_position, n_elements): +def _position_validity_checker_ring_rod( + position: NDArray[np.floating], + ring_center_position: NDArray[np.floating], + n_elements: int, +) -> None: """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements), "position") diff --git a/elastica/rod/knot_theory.py b/elastica/rod/knot_theory.py index 23d2b009..537d8d4e 100644 --- a/elastica/rod/knot_theory.py +++ b/elastica/rod/knot_theory.py @@ -14,6 +14,7 @@ from numba import njit import numpy as np +from numpy.typing import NDArray from elastica.rod.rod_base import RodBase from elastica._linalg import _batch_norm, _batch_dot, _batch_cross @@ -38,7 +39,7 @@ def radius(self) -> np.ndarray: ... def base_length(self) -> np.ndarray: ... -class KnotTheory: +class KnotTheory(KnotTheoryCompatibleProtocol): """ This mixin should be used in RodBase-derived class that satisfies KnotCompatibleProtocol. The theory behind this module is based on the method from Klenin & Langowski 2000 paper. @@ -46,7 +47,7 @@ class KnotTheory: KnotTheory can be mixed with any rod-class based on RodBase:: class MyRod(RodBase, KnotTheory): - def __init__(self): + def __init__(self) -> None: super().__init__() rod = MyRod(...) @@ -78,7 +79,7 @@ def __init__(self): MIXIN_PROTOCOL = Union[RodBase, KnotTheoryCompatibleProtocol] - def compute_twist(self: MIXIN_PROTOCOL): + def compute_twist(self: MIXIN_PROTOCOL) -> NDArray[np.floating]: """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. """ @@ -91,8 +92,8 @@ def compute_twist(self: MIXIN_PROTOCOL): def compute_writhe( self: MIXIN_PROTOCOL, type_of_additional_segment: str = "next_tangent", - alpha: float = 1.0, - ): + alpha: np.floating = 1.0, + ) -> NDArray[np.floating]: """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. @@ -114,8 +115,8 @@ def compute_writhe( def compute_link( self: MIXIN_PROTOCOL, type_of_additional_segment: str = "next_tangent", - alpha: float = 1.0, - ): + alpha: np.floating = 1.0, + ) -> NDArray[np.floating]: """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. @@ -138,7 +139,9 @@ def compute_link( )[0] -def compute_twist(center_line, normal_collection): +def compute_twist( + center_line: NDArray[np.floating], normal_collection: NDArray[np.floating] +) -> tuple[NDArray[np.floating], NDArray[np.floating]]: """ Compute the twist of a rod, using center_line and normal collection. @@ -189,7 +192,9 @@ def compute_twist(center_line, normal_collection): @njit(cache=True) -def _compute_twist(center_line, normal_collection): +def _compute_twist( + center_line: NDArray[np.floating], normal_collection: NDArray[np.floating] +) -> tuple[NDArray[np.floating], NDArray[np.floating]]: """ Parameters ---------- @@ -264,7 +269,11 @@ def _compute_twist(center_line, normal_collection): return total_twist, local_twist -def compute_writhe(center_line, segment_length, type_of_additional_segment): +def compute_writhe( + center_line: NDArray[np.floating], + segment_length: np.floating, + type_of_additional_segment: str, +) -> NDArray[np.floating]: """ This function computes the total writhe history of a rod. @@ -314,7 +323,7 @@ def compute_writhe(center_line, segment_length, type_of_additional_segment): @njit(cache=True) -def _compute_writhe(center_line): +def _compute_writhe(center_line: NDArray[np.floating]) -> NDArray[np.floating]: """ Parameters ---------- @@ -386,12 +395,12 @@ def _compute_writhe(center_line): def compute_link( - center_line: np.ndarray, - normal_collection: np.ndarray, - radius: np.ndarray, - segment_length: float, + center_line: NDArray[np.floating], + normal_collection: NDArray[np.floating], + radius: NDArray[np.floating], + segment_length: np.floating, type_of_additional_segment: str, -): +) -> NDArray[np.floating]: """ This function computes the total link history of a rod. @@ -470,7 +479,11 @@ def compute_link( @njit(cache=True) -def _compute_auxiliary_line(center_line, normal_collection, radius): +def _compute_auxiliary_line( + center_line: NDArray[np.floating], + normal_collection: NDArray[np.floating], + radius: NDArray[np.floating], +) -> NDArray[np.floating]: """ This function computes the auxiliary line using rod center line and normal collection. @@ -525,7 +538,9 @@ def _compute_auxiliary_line(center_line, normal_collection, radius): @njit(cache=True) -def _compute_link(center_line, auxiliary_line): +def _compute_link( + center_line: NDArray[np.floating], auxiliary_line: NDArray[np.floating] +) -> NDArray[np.floating]: """ Parameters @@ -604,8 +619,11 @@ def _compute_link(center_line, auxiliary_line): @njit(cache=True) def _compute_auxiliary_line_added_segments( - beginning_direction, end_direction, auxiliary_line, segment_length -): + beginning_direction: NDArray[np.floating], + end_direction: NDArray[np.floating], + auxiliary_line: NDArray[np.floating], + segment_length: np.floating, +) -> NDArray[np.floating]: """ This code is for computing position of added segments to the auxiliary line. @@ -647,8 +665,10 @@ def _compute_auxiliary_line_added_segments( @njit(cache=True) def _compute_additional_segment( - center_line, segment_length, type_of_additional_segment -): + center_line: NDArray[np.floating], + segment_length: np.floating, + type_of_additional_segment: str, +) -> tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]: """ This function adds two points at the end of center line. Distance from the center line is given by segment_length. Direction from center line to the new point locations can be computed using 3 methods, which can be selected by diff --git a/elastica/rod/rod_base.py b/elastica/rod/rod_base.py index 314efe6d..d49c0431 100644 --- a/elastica/rod/rod_base.py +++ b/elastica/rod/rod_base.py @@ -1,5 +1,9 @@ __doc__ = """Base class for rods""" +from typing import Any +import numpy as np +from numpy.typing import NDArray + class RodBase: """ @@ -11,10 +15,15 @@ class RodBase: """ - REQUISITE_MODULES = [] + REQUISITE_MODULES: list[Any] = [] - def __init__(self): + def __init__(self) -> None: """ RodBase does not take any arguments. """ - pass + self.position_collection: NDArray[np.floating] + self.omega_collection: NDArray[np.floating] + self.acceleration_collection: NDArray[np.floating] + self.alpha_collection: NDArray[np.floating] + self.external_forces: NDArray[np.floating] + self.external_torques: NDArray[np.floating]