Skip to content

Commit

Permalink
Start work on rod type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
ankith26 committed May 7, 2024
1 parent 17156b9 commit 7afa435
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 70 deletions.
85 changes: 43 additions & 42 deletions elastica/rod/cosserat_rod.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


import numpy as np
from numpy.typing import NDArray
import functools
import numba
from elastica.rod import RodBase
Expand Down Expand Up @@ -147,39 +148,39 @@ class CosseratRod(RodBase, KnotTheory):

def __init__(
self,
n_elements,
position,
velocity,
omega,
acceleration,
angular_acceleration,
directors,
radius,
mass_second_moment_of_inertia,
inv_mass_second_moment_of_inertia,
shear_matrix,
bend_matrix,
density,
volume,
mass,
internal_forces,
internal_torques,
external_forces,
external_torques,
lengths,
rest_lengths,
tangents,
dilatation,
dilatation_rate,
voronoi_dilatation,
rest_voronoi_lengths,
sigma,
kappa,
rest_sigma,
rest_kappa,
internal_stress,
internal_couple,
ring_rod_flag,
n_elements: int,
position: NDArray[np.floating],
velocity: NDArray[np.floating],
omega: NDArray[np.floating],
acceleration: NDArray[np.floating],
angular_acceleration: NDArray[np.floating],
directors: NDArray[np.floating],
radius: NDArray[np.floating],
mass_second_moment_of_inertia: NDArray[np.floating],
inv_mass_second_moment_of_inertia: NDArray[np.floating],
shear_matrix: NDArray[np.floating],
bend_matrix: NDArray[np.floating],
density: NDArray[np.floating],
volume: NDArray[np.floating],
mass: NDArray[np.floating],
internal_forces: NDArray[np.floating],
internal_torques: NDArray[np.floating],
external_forces: NDArray[np.floating],
external_torques: NDArray[np.floating],
lengths: NDArray[np.floating],
rest_lengths: NDArray[np.floating],
tangents: NDArray[np.floating],
dilatation: NDArray[np.floating],
dilatation_rate: NDArray[np.floating],
voronoi_dilatation: NDArray[np.floating],
rest_voronoi_lengths: NDArray[np.floating],
sigma: NDArray[np.floating],
kappa: NDArray[np.floating],
rest_sigma: NDArray[np.floating],
rest_kappa: NDArray[np.floating],
internal_stress: NDArray[np.floating],
internal_couple: NDArray[np.floating],
ring_rod_flag: bool,
):
self.n_elems = n_elements
self.position_collection = position
Expand Down Expand Up @@ -242,9 +243,9 @@ def __init__(
def straight_rod(
cls,
n_elements: int,
start: np.ndarray,
direction: np.ndarray,
normal: np.ndarray,
start: NDArray[np.floating],
direction: NDArray[np.floating],
normal: NDArray[np.floating],
base_length: float,
base_radius: float,
density: float,
Expand Down Expand Up @@ -390,9 +391,9 @@ def straight_rod(
def ring_rod(
cls,
n_elements: int,
ring_center_position: np.ndarray,
direction: np.ndarray,
normal: np.ndarray,
ring_center_position: NDArray[np.floating],
direction: NDArray[np.floating],
normal: NDArray[np.floating],
base_length: float,
base_radius: float,
density: float,
Expand Down Expand Up @@ -533,7 +534,7 @@ def ring_rod(
ring_rod_flag,
)

def compute_internal_forces_and_torques(self, time):
def compute_internal_forces_and_torques(self, time: float):
"""
Compute internal forces and torques. We need to compute internal forces and torques before the acceleration because
they are used in interaction. Thus in order to speed up simulation, we will compute internal forces and torques
Expand Down Expand Up @@ -588,7 +589,7 @@ def compute_internal_forces_and_torques(self, time):
)

# Interface to time-stepper mixins (Symplectic, Explicit), which calls this method
def update_accelerations(self, time):
def update_accelerations(self, time: float):
"""
Updates the acceleration variables
Expand All @@ -610,7 +611,7 @@ def update_accelerations(self, time):
self.dilatation,
)

def zeroed_out_external_forces_and_torques(self, time):
def zeroed_out_external_forces_and_torques(self, time: float):
_zeroed_out_external_forces_and_torques(
self.external_forces, self.external_torques
)
Expand Down
31 changes: 20 additions & 11 deletions elastica/rod/factory_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
import logging
import numpy as np
from numpy.testing import assert_allclose
from numpy.typing import NDArray
from elastica.utils import MaxDimension, Tolerance
from elastica._linalg import _batch_cross, _batch_norm, _batch_dot


def allocate(
n_elements,
direction,
normal,
base_length,
base_radius,
density,
n_elements: int,
direction: NDArray[np.floating],
normal: NDArray[np.floating],
base_length: float,
base_radius: float,
density: float,
youngs_modulus: float,
*,
rod_origin_position: np.ndarray,
Expand Down Expand Up @@ -335,14 +336,14 @@ def allocate(
"""


def _assert_dim(vector, max_dim: int, name: str):
def _assert_dim(vector: np.ndarray, max_dim: int, name: str):
assert vector.ndim < max_dim, (
f"Input {name} dimension is not correct {vector.shape}"
+ f" It should be maximum {max_dim}D vector or single floating number."
)


def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str):
def _assert_shape(array: np.ndarray, expected_shape: Tuple[int, ...], name: str):
assert array.shape == expected_shape, (
f"Given {name} shape is not correct, it should be "
+ str(expected_shape)
Expand All @@ -351,7 +352,9 @@ def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str):
)


def _position_validity_checker(position, start, n_elements):
def _position_validity_checker(
position: NDArray[np.floating], start: NDArray[np.floating], n_elements: int
):
"""Checker on user-defined position validity"""
_assert_shape(position, (MaxDimension.value(), n_elements + 1), "position")

Expand All @@ -367,7 +370,9 @@ def _position_validity_checker(position, start, n_elements):
)


def _directors_validity_checker(directors, tangents, n_elements):
def _directors_validity_checker(
directors: NDArray[np.floating], tangents: NDArray[np.floating], n_elements: int
):
"""Checker on user-defined directors validity"""
_assert_shape(
directors, (MaxDimension.value(), MaxDimension.value(), n_elements), "directors"
Expand Down Expand Up @@ -413,7 +418,11 @@ def _directors_validity_checker(directors, tangents, n_elements):
)


def _position_validity_checker_ring_rod(position, ring_center_position, n_elements):
def _position_validity_checker_ring_rod(
position: NDArray[np.floating],
ring_center_position: NDArray[np.floating],
n_elements: int,
):
"""Checker on user-defined position validity"""
_assert_shape(position, (MaxDimension.value(), n_elements), "position")

Expand Down
42 changes: 31 additions & 11 deletions elastica/rod/knot_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from numba import njit
import numpy as np
from numpy.typing import NDArray

from elastica.rod.rod_base import RodBase
from elastica._linalg import _batch_norm, _batch_dot, _batch_cross
Expand Down Expand Up @@ -138,7 +139,9 @@ def compute_link(
)[0]


def compute_twist(center_line, normal_collection):
def compute_twist(
center_line: NDArray[np.floating], normal_collection: NDArray[np.floating]
):
"""
Compute the twist of a rod, using center_line and normal collection.
Expand Down Expand Up @@ -189,7 +192,9 @@ def compute_twist(center_line, normal_collection):


@njit(cache=True)
def _compute_twist(center_line, normal_collection):
def _compute_twist(
center_line: NDArray[np.floating], normal_collection: NDArray[np.floating]
):
"""
Parameters
----------
Expand Down Expand Up @@ -264,7 +269,11 @@ def _compute_twist(center_line, normal_collection):
return total_twist, local_twist


def compute_writhe(center_line, segment_length, type_of_additional_segment):
def compute_writhe(
center_line: NDArray[np.floating],
segment_length: float,
type_of_additional_segment: str,
):
"""
This function computes the total writhe history of a rod.
Expand Down Expand Up @@ -314,7 +323,7 @@ def compute_writhe(center_line, segment_length, type_of_additional_segment):


@njit(cache=True)
def _compute_writhe(center_line):
def _compute_writhe(center_line: NDArray[np.floating]):
"""
Parameters
----------
Expand Down Expand Up @@ -386,9 +395,9 @@ def _compute_writhe(center_line):


def compute_link(
center_line: np.ndarray,
normal_collection: np.ndarray,
radius: np.ndarray,
center_line: NDArray[np.floating],
normal_collection: NDArray[np.floating],
radius: NDArray[np.floating],
segment_length: float,
type_of_additional_segment: str,
):
Expand Down Expand Up @@ -470,7 +479,11 @@ def compute_link(


@njit(cache=True)
def _compute_auxiliary_line(center_line, normal_collection, radius):
def _compute_auxiliary_line(
center_line: NDArray[np.floating],
normal_collection: NDArray[np.floating],
radius: NDArray[np.floating],
):
"""
This function computes the auxiliary line using rod center line and normal collection.
Expand Down Expand Up @@ -525,7 +538,9 @@ def _compute_auxiliary_line(center_line, normal_collection, radius):


@njit(cache=True)
def _compute_link(center_line, auxiliary_line):
def _compute_link(
center_line: NDArray[np.floating], auxiliary_line: NDArray[np.floating]
):
"""
Parameters
Expand Down Expand Up @@ -604,7 +619,10 @@ def _compute_link(center_line, auxiliary_line):

@njit(cache=True)
def _compute_auxiliary_line_added_segments(
beginning_direction, end_direction, auxiliary_line, segment_length
beginning_direction: NDArray[np.floating],
end_direction: NDArray[np.floating],
auxiliary_line: NDArray[np.floating],
segment_length: float,
):
"""
This code is for computing position of added segments to the auxiliary line.
Expand Down Expand Up @@ -647,7 +665,9 @@ def _compute_auxiliary_line_added_segments(

@njit(cache=True)
def _compute_additional_segment(
center_line, segment_length, type_of_additional_segment
center_line: NDArray[np.floating],
segment_length: float,
type_of_additional_segment: str,
):
"""
This function adds two points at the end of center line. Distance from the center line is given by segment_length.
Expand Down
14 changes: 8 additions & 6 deletions elastica/rod/rod_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
__doc__ = """Base class for rods"""

import numpy as np
from numpy.typing import NDArray

class RodBase:
"""
Expand All @@ -15,9 +17,9 @@ def __init__(self) -> None:
"""
RodBase does not take any arguments.
"""
self.position_collection: int
self.omega_collection: int
self.acceleration_collection: int
self.alpha_collection: int
self.external_forces: int
self.external_torques: int
self.position_collection: NDArray[np.floating]
self.omega_collection: NDArray[np.floating]
self.acceleration_collection: NDArray[np.floating]
self.alpha_collection: NDArray[np.floating]
self.external_forces: NDArray[np.floating]
self.external_torques: NDArray[np.floating]

0 comments on commit 7afa435

Please sign in to comment.