From 5041d0bdb12be2b08955575028a0ceca345e01e3 Mon Sep 17 00:00:00 2001 From: Ankith Date: Tue, 14 May 2024 13:07:38 +0530 Subject: [PATCH] Improve typehinting at root and rod directory --- elastica/_calculus.py | 32 ++- elastica/_contact_functions.py | 214 +++++++------- elastica/_linalg.py | 45 ++- elastica/_rotations.py | 48 ++-- elastica/_synchronize_periodic_boundary.py | 23 +- elastica/boundary_conditions.py | 113 +++++--- elastica/callback_functions.py | 25 +- elastica/contact_forces.py | 58 ++-- elastica/contact_utils.py | 89 +++--- elastica/dissipation.py | 47 +-- elastica/external_forces.py | 125 +++++--- elastica/interaction.py | 172 ++++++----- elastica/joint.py | 219 ++++++++------ elastica/restart.py | 18 +- elastica/rod/cosserat_rod.py | 320 +++++++++++---------- elastica/rod/data_structures.py | 102 ++++--- elastica/rod/factory_function.py | 76 +++-- elastica/rod/knot_theory.py | 62 ++-- elastica/rod/rod_base.py | 14 +- elastica/transformations.py | 21 +- elastica/typing.py | 2 +- elastica/utils.py | 28 +- 22 files changed, 1102 insertions(+), 751 deletions(-) diff --git a/elastica/_calculus.py b/elastica/_calculus.py index eca9829b..ae730da1 100644 --- a/elastica/_calculus.py +++ b/elastica/_calculus.py @@ -1,6 +1,8 @@ __doc__ = """ Quadrature and difference kernels """ +from typing import Any, Union import numpy as np from numpy import zeros, empty +from numpy.typing import NDArray from numba import njit from elastica.reset_functions_for_block_structure._reset_ghost_vector_or_scalar import ( _reset_vector_ghost, @@ -9,15 +11,17 @@ @functools.lru_cache(maxsize=2) -def _get_zero_array(dim, ndim): +def _get_zero_array(dim: int, ndim: int) -> Union[float, NDArray[np.floating], None]: if ndim == 1: return 0.0 if ndim == 2: return np.zeros((dim, 1)) + return None + @njit(cache=True) -def _trapezoidal(array_collection): +def _trapezoidal(array_collection: NDArray[np.floating]) -> NDArray[np.floating]: """ Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way @@ -63,7 +67,9 @@ def _trapezoidal(array_collection): @njit(cache=True) -def _trapezoidal_for_block_structure(array_collection, ghost_idx): +def _trapezoidal_for_block_structure( + array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer] +) -> NDArray[np.floating]: """ Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way. This form specifically for the block structure implementation and there is a reset function call, to reset @@ -115,7 +121,9 @@ def _trapezoidal_for_block_structure(array_collection, ghost_idx): @njit(cache=True) -def _two_point_difference(array_collection): +def _two_point_difference( + array_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ This function does differentiation. @@ -156,7 +164,9 @@ def _two_point_difference(array_collection): @njit(cache=True) -def _two_point_difference_for_block_structure(array_collection, ghost_idx): +def _two_point_difference_for_block_structure( + array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer] +) -> NDArray[np.floating]: """ This function does the differentiation, for Cosserat rod model equations. This form specifically for the block structure implementation and there is a reset function call, to @@ -207,7 +217,7 @@ def _two_point_difference_for_block_structure(array_collection, ghost_idx): @njit(cache=True) -def _difference(vector): +def _difference(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ This function computes difference between elements of a batch vector. @@ -238,7 +248,7 @@ def _difference(vector): @njit(cache=True) -def _average(vector): +def _average(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ This function computes the average between elements of a vector. @@ -268,7 +278,9 @@ def _average(vector): @njit(cache=True) -def _clip_array(input_array, vmin, vmax): +def _clip_array( + input_array: NDArray[np.floating], vmin: np.floating, vmax: np.floating +) -> NDArray[np.floating]: """ This function clips an array values between user defined minimum and maximum @@ -304,7 +316,7 @@ def _clip_array(input_array, vmin, vmax): @njit(cache=True) -def _isnan_check(array): +def _isnan_check(array: NDArray[Any]) -> bool: """ This function checks if there is any nan inside the array. If there is nan, it returns True boolean. @@ -324,7 +336,7 @@ def _isnan_check(array): Python version: 2.24 µs ± 96.1 ns per loop This version: 479 ns ± 6.49 ns per loop """ - return np.isnan(array).any() + return bool(np.isnan(array).any()) position_difference_kernel = _difference diff --git a/elastica/_contact_functions.py b/elastica/_contact_functions.py index 245d9231..67bbcc2c 100644 --- a/elastica/_contact_functions.py +++ b/elastica/_contact_functions.py @@ -24,28 +24,29 @@ ) import numba import numpy as np +from numpy.typing import NDArray @numba.njit(cache=True) def _calculate_contact_forces_rod_cylinder( - x_collection_rod, - edge_collection_rod, - x_cylinder_center, - x_cylinder_tip, - edge_cylinder, - radii_sum, - length_sum, - internal_forces_rod, - external_forces_rod, - external_forces_cylinder, - external_torques_cylinder, - cylinder_director_collection, - velocity_rod, - velocity_cylinder, - contact_k, - contact_nu, - velocity_damping_coefficient, - friction_coefficient, + x_collection_rod: NDArray[np.floating], + edge_collection_rod: NDArray[np.floating], + x_cylinder_center: NDArray[np.floating], + x_cylinder_tip: NDArray[np.floating], + edge_cylinder: NDArray[np.floating], + radii_sum: NDArray[np.floating], + length_sum: NDArray[np.floating], + internal_forces_rod: NDArray[np.floating], + external_forces_rod: NDArray[np.floating], + external_forces_cylinder: NDArray[np.floating], + external_torques_cylinder: NDArray[np.floating], + cylinder_director_collection: NDArray[np.floating], + velocity_rod: NDArray[np.floating], + velocity_cylinder: NDArray[np.floating], + contact_k: np.floating, + contact_nu: np.floating, + velocity_damping_coefficient: np.floating, + friction_coefficient: np.floating, ) -> None: # We already pass in only the first n_elem x n_points = x_collection_rod.shape[1] @@ -155,22 +156,22 @@ def _calculate_contact_forces_rod_cylinder( @numba.njit(cache=True) def _calculate_contact_forces_rod_rod( - x_collection_rod_one, - radius_rod_one, - length_rod_one, - tangent_rod_one, - velocity_rod_one, - internal_forces_rod_one, - external_forces_rod_one, - x_collection_rod_two, - radius_rod_two, - length_rod_two, - tangent_rod_two, - velocity_rod_two, - internal_forces_rod_two, - external_forces_rod_two, - contact_k, - contact_nu, + x_collection_rod_one: NDArray[np.floating], + radius_rod_one: NDArray[np.floating], + length_rod_one: NDArray[np.floating], + tangent_rod_one: NDArray[np.floating], + velocity_rod_one: NDArray[np.floating], + internal_forces_rod_one: NDArray[np.floating], + external_forces_rod_one: NDArray[np.floating], + x_collection_rod_two: NDArray[np.floating], + radius_rod_two: NDArray[np.floating], + length_rod_two: NDArray[np.floating], + tangent_rod_two: NDArray[np.floating], + velocity_rod_two: NDArray[np.floating], + internal_forces_rod_two: NDArray[np.floating], + external_forces_rod_two: NDArray[np.floating], + contact_k: np.floating, + contact_nu: np.floating, ) -> None: # We already pass in only the first n_elem x n_points_rod_one = x_collection_rod_one.shape[1] @@ -272,14 +273,14 @@ def _calculate_contact_forces_rod_rod( @numba.njit(cache=True) def _calculate_contact_forces_self_rod( - x_collection_rod, - radius_rod, - length_rod, - tangent_rod, - velocity_rod, - external_forces_rod, - contact_k, - contact_nu, + x_collection_rod: NDArray[np.floating], + radius_rod: NDArray[np.floating], + length_rod: NDArray[np.floating], + tangent_rod: NDArray[np.floating], + velocity_rod: NDArray[np.floating], + external_forces_rod: NDArray[np.floating], + contact_k: np.floating, + contact_nu: np.floating, ) -> None: # We already pass in only the first n_elem x n_points_rod = x_collection_rod.shape[1] @@ -360,24 +361,24 @@ def _calculate_contact_forces_self_rod( @numba.njit(cache=True) def _calculate_contact_forces_rod_sphere( - x_collection_rod, - edge_collection_rod, - x_sphere_center, - x_sphere_tip, - edge_sphere, - radii_sum, - length_sum, - internal_forces_rod, - external_forces_rod, - external_forces_sphere, - external_torques_sphere, - sphere_director_collection, - velocity_rod, - velocity_sphere, - contact_k, - contact_nu, - velocity_damping_coefficient, - friction_coefficient, + x_collection_rod: NDArray[np.floating], + edge_collection_rod: NDArray[np.floating], + x_sphere_center: NDArray[np.floating], + x_sphere_tip: NDArray[np.floating], + edge_sphere: NDArray[np.floating], + radii_sum: NDArray[np.floating], + length_sum: NDArray[np.floating], + internal_forces_rod: NDArray[np.floating], + external_forces_rod: NDArray[np.floating], + external_forces_sphere: NDArray[np.floating], + external_torques_sphere: NDArray[np.floating], + sphere_director_collection: NDArray[np.floating], + velocity_rod: NDArray[np.floating], + velocity_sphere: NDArray[np.floating], + contact_k: np.floating, + contact_nu: np.floating, + velocity_damping_coefficient: np.floating, + friction_coefficient: np.floating, ) -> None: # We already pass in only the first n_elem x n_points = x_collection_rod.shape[1] @@ -486,18 +487,18 @@ def _calculate_contact_forces_rod_sphere( @numba.njit(cache=True) def _calculate_contact_forces_rod_plane( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - radius, - mass, - position_collection, - velocity_collection, - internal_forces, - external_forces, -): + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + surface_tol: np.floating, + k: np.floating, + nu: np.floating, + radius: NDArray[np.floating], + mass: NDArray[np.floating], + position_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + internal_forces: NDArray[np.floating], + external_forces: NDArray[np.floating], +) -> tuple[NDArray[np.floating], NDArray[np.intp]]: """ This function computes the plane force response on the element, in the case of contact. Contact model given in Eqn 4.8 Gazzola et. al. RSoS 2018 paper @@ -571,30 +572,30 @@ def _calculate_contact_forces_rod_plane( @numba.njit(cache=True) def _calculate_contact_forces_rod_plane_with_anisotropic_friction( - plane_origin, - plane_normal, - surface_tol, - slip_velocity_tol, - k, - nu, - kinetic_mu_forward, - kinetic_mu_backward, - kinetic_mu_sideways, - static_mu_forward, - static_mu_backward, - static_mu_sideways, - radius, - mass, - tangents, - position_collection, - director_collection, - velocity_collection, - omega_collection, - internal_forces, - external_forces, - internal_torques, - external_torques, -): + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + surface_tol: np.floating, + slip_velocity_tol: np.floating, + k: np.floating, + nu: np.floating, + kinetic_mu_forward: np.floating, + kinetic_mu_backward: np.floating, + kinetic_mu_sideways: np.floating, + static_mu_forward: np.floating, + static_mu_backward: np.floating, + static_mu_sideways: np.floating, + radius: NDArray[np.floating], + mass: NDArray[np.floating], + tangents: NDArray[np.floating], + position_collection: NDArray[np.floating], + director_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + omega_collection: NDArray[np.floating], + internal_forces: NDArray[np.floating], + external_forces: NDArray[np.floating], + internal_torques: NDArray[np.floating], + external_torques: NDArray[np.floating], +) -> None: ( plane_response_force_mag, no_contact_point_idx, @@ -784,17 +785,16 @@ def _calculate_contact_forces_rod_plane_with_anisotropic_friction( @numba.njit(cache=True) def _calculate_contact_forces_cylinder_plane( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - length, - position_collection, - velocity_collection, - external_forces, -): - + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + surface_tol: np.floating, + k: np.floating, + nu: np.floating, + length: NDArray[np.floating], + position_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + external_forces: NDArray[np.floating], +) -> tuple[NDArray[np.floating], NDArray[np.intp]]: # Compute plane response force # total_forces = system.internal_forces + system.external_forces total_forces = external_forces diff --git a/elastica/_linalg.py b/elastica/_linalg.py index a4995ab2..3123ff75 100644 --- a/elastica/_linalg.py +++ b/elastica/_linalg.py @@ -1,5 +1,6 @@ __doc__ = """ Convenient linear algebra kernels """ import numpy as np +from numpy.typing import NDArray from numba import njit from numpy import sqrt import functools @@ -8,7 +9,7 @@ @functools.lru_cache(maxsize=1) -def levi_civita_tensor(dim): +def levi_civita_tensor(dim: int) -> NDArray[np.floating]: """ Parameters @@ -28,7 +29,9 @@ def levi_civita_tensor(dim): @njit(cache=True) -def _batch_matvec(matrix_collection, vector_collection): +def _batch_matvec( + matrix_collection: NDArray[np.floating], vector_collection: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does batch matrix and batch vector product @@ -59,7 +62,10 @@ def _batch_matvec(matrix_collection, vector_collection): @njit(cache=True) -def _batch_matmul(first_matrix_collection, second_matrix_collection): +def _batch_matmul( + first_matrix_collection: NDArray[np.floating], + second_matrix_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ This is batch matrix matrix multiplication function. Only batch of 3x3 matrices can be multiplied. @@ -93,7 +99,10 @@ def _batch_matmul(first_matrix_collection, second_matrix_collection): @njit(cache=True) -def _batch_cross(first_vector_collection, second_vector_collection): +def _batch_cross( + first_vector_collection: NDArray[np.floating], + second_vector_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ This function does cross product between two batch vectors. @@ -133,7 +142,9 @@ def _batch_cross(first_vector_collection, second_vector_collection): @njit(cache=True) -def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector): +def _batch_vec_oneD_vec_cross( + first_vector_collection: NDArray[np.floating], second_vector: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does cross product between batch vector and a 1D vector. Idea of having this function is that, for friction calculations, we dont @@ -177,7 +188,9 @@ def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector): @njit(cache=True) -def _batch_dot(first_vector, second_vector): +def _batch_dot( + first_vector: NDArray[np.floating], second_vector: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does batch vec and batch vec dot product. Parameters @@ -204,7 +217,7 @@ def _batch_dot(first_vector, second_vector): @njit(cache=True) -def _batch_norm(vector): +def _batch_norm(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ This function computes norm of a batch vector Parameters @@ -233,7 +246,9 @@ def _batch_norm(vector): @njit(cache=True) -def _batch_product_i_k_to_ik(vector1, vector2): +def _batch_product_i_k_to_ik( + vector1: NDArray[np.floating], vector2: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does outer product following 'i,k->ik'. vector1 has shape of 3 and vector 2 has shape of blocksize @@ -262,7 +277,9 @@ def _batch_product_i_k_to_ik(vector1, vector2): @njit(cache=True) -def _batch_product_i_ik_to_k(vector1, vector2): +def _batch_product_i_ik_to_k( + vector1: NDArray[np.floating], vector2: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does the following product 'i,ik->k' This function do dot product between a vector of 3 elements @@ -293,7 +310,9 @@ def _batch_product_i_ik_to_k(vector1, vector2): @njit(cache=True) -def _batch_product_k_ik_to_ik(vector1, vector2): +def _batch_product_k_ik_to_ik( + vector1: NDArray[np.floating], vector2: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does the following product 'k, ik->ik' Parameters @@ -322,7 +341,9 @@ def _batch_product_k_ik_to_ik(vector1, vector2): @njit(cache=True) -def _batch_vector_sum(vector1, vector2): +def _batch_vector_sum( + vector1: NDArray[np.floating], vector2: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function is for summing up two vectors. Although this function is not faster than pure python implementation @@ -352,7 +373,7 @@ def _batch_vector_sum(vector1, vector2): @njit(cache=True) -def _batch_matrix_transpose(input_matrix): +def _batch_matrix_transpose(input_matrix: NDArray[np.floating]) -> NDArray[np.floating]: """ This function takes an batch input matrix and transpose it. Parameters diff --git a/elastica/_rotations.py b/elastica/_rotations.py index 25ec1421..7fb9d063 100644 --- a/elastica/_rotations.py +++ b/elastica/_rotations.py @@ -8,6 +8,7 @@ from numpy import cos from numpy import sqrt from numpy import arccos +from numpy.typing import NDArray from numba import njit @@ -15,7 +16,9 @@ @njit(cache=True) -def _get_rotation_matrix(scale: float, axis_collection): +def _get_rotation_matrix( + scale: np.floating, axis_collection: NDArray[np.floating] +) -> NDArray[np.floating]: blocksize = axis_collection.shape[1] rot_mat = np.empty((3, 3, blocksize)) @@ -49,7 +52,11 @@ def _get_rotation_matrix(scale: float, axis_collection): @njit(cache=True) -def _rotate(director_collection, scale: float, axis_collection): +def _rotate( + director_collection: NDArray[np.floating], + scale: np.floating, + axis_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ Does alibi rotations https://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities @@ -74,7 +81,7 @@ def _rotate(director_collection, scale: float, axis_collection): @njit(cache=True) -def _inv_rotate(director_collection): +def _inv_rotate(director_collection: NDArray[np.floating]) -> NDArray[np.floating]: """ Calculated rate of change using Rodrigues' formula @@ -156,12 +163,15 @@ def _inv_rotate(director_collection): return vector_collection +_generate_skew_map_sentinel = (0, 0, 0) + + # TODO: Below contains numpy-only implementations @functools.lru_cache(maxsize=1) -def _generate_skew_map(dim: int): +def _generate_skew_map(dim: int) -> list[tuple[int, int, int]]: # TODO Documentation # Preallocate - mapping_list = [None] * ((dim**2 - dim) // 2) + mapping_list = [_generate_skew_map_sentinel] * ((dim**2 - dim) // 2) # Indexing (i,j), j is the fastest changing # r = 2, r here is rank, we deal with only matrices for index, (i, j) in enumerate(combinations(range(dim), r=2)): @@ -185,7 +195,7 @@ def _generate_skew_map(dim: int): @functools.lru_cache(maxsize=1) -def _get_skew_map(dim): +def _get_skew_map(dim: int) -> tuple[tuple[int, int, int], ...]: """Generates mapping from src to target skew-symmetric operator For input vector V and output Matrix M (represented in lexicographical index), @@ -208,7 +218,7 @@ def _get_skew_map(dim): @functools.lru_cache(maxsize=1) -def _get_inv_skew_map(dim): +def _get_inv_skew_map(dim: int) -> tuple[tuple[int, int, int], ...]: # TODO Documentation # (vec_src, mat_i, mat_j, sign) mapping_list = _generate_skew_map(dim) @@ -219,7 +229,7 @@ def _get_inv_skew_map(dim): @functools.lru_cache(maxsize=1) -def _get_diag_map(dim): +def _get_diag_map(dim: int) -> tuple[int, ...]: """Generates lexicographic mapping to diagonal in a serialized matrix-type For input dimension dim we calculate mapping to * in Matrix M below @@ -231,17 +241,10 @@ def _get_diag_map(dim): in a dimension agnostic way. """ - # Preallocate - mapping_list = [None] * dim - - # Store linear indices - for dim_iter in range(dim): - mapping_list[dim_iter] = dim_iter * (dim + 1) - - return tuple(mapping_list) + return tuple([dim_iter * (dim + 1) for dim_iter in range(dim)]) -def _skew_symmetrize(vector): +def _skew_symmetrize(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ Parameters @@ -276,7 +279,7 @@ def _skew_symmetrize(vector): # This is purely for testing and optimization sake # While calculating u^2, use u with einsum instead, as it is tad bit faster -def _skew_symmetrize_sq(vector): +def _skew_symmetrize_sq(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ Generate the square of an orthogonal matrix from vector elements @@ -298,12 +301,11 @@ def _skew_symmetrize_sq(vector): hardcoded : 23.1 µs ± 481 ns per loop this version: 14.1 µs ± 96.9 ns per loop """ - dim, _ = vector.shape # First generate array of [x^2, xy, xz, yx, y^2, yz, zx, zy, z^2] # across blocksize # This is slightly faster than doing v[np.newaxis,:,:] * v[:,np.newaxis,:] - products_xy = np.einsum("ik,jk->ijk", vector, vector) + products_xy: NDArray[np.floating] = np.einsum("ik,jk->ijk", vector, vector) # No copy made here, as we do not change memory layout # products_xy = products_xy.reshape((dim * dim, -1)) @@ -335,7 +337,9 @@ def _skew_symmetrize_sq(vector): return products_xy -def _get_skew_symmetric_pair(vector_collection): +def _get_skew_symmetric_pair( + vector_collection: NDArray[np.floating], +) -> tuple[NDArray[np.floating], NDArray[np.floating]]: """ Parameters @@ -351,7 +355,7 @@ def _get_skew_symmetric_pair(vector_collection): return u, u_sq -def _inv_skew_symmetrize(matrix): +def _inv_skew_symmetrize(matrix: NDArray[np.floating]) -> NDArray[np.floating]: """ Return the vector elements from a skew-symmetric matrix M diff --git a/elastica/_synchronize_periodic_boundary.py b/elastica/_synchronize_periodic_boundary.py index b4fe87b4..0a06622c 100644 --- a/elastica/_synchronize_periodic_boundary.py +++ b/elastica/_synchronize_periodic_boundary.py @@ -2,12 +2,18 @@ """These functions are used to synchronize periodic boundaries for ring rods. """ ) +from typing import Any from numba import njit +import numpy as np +from numpy.typing import NDArray from elastica.boundary_conditions import ConstraintBase +from elastica.typing import SystemType @njit(cache=True) -def _synchronize_periodic_boundary_of_vector_collection(input, periodic_idx): +def _synchronize_periodic_boundary_of_vector_collection( + input: NDArray[np.floating], periodic_idx: NDArray[np.floating] +) -> None: """ This function synchronizes the periodic boundaries of a vector collection. Parameters @@ -28,7 +34,9 @@ def _synchronize_periodic_boundary_of_vector_collection(input, periodic_idx): @njit(cache=True) -def _synchronize_periodic_boundary_of_matrix_collection(input, periodic_idx): +def _synchronize_periodic_boundary_of_matrix_collection( + input: NDArray[np.floating], periodic_idx: NDArray[np.floating] +) -> None: """ This function synchronizes the periodic boundaries of a matrix collection. Parameters @@ -50,7 +58,9 @@ def _synchronize_periodic_boundary_of_matrix_collection(input, periodic_idx): @njit(cache=True) -def _synchronize_periodic_boundary_of_scalar_collection(input, periodic_idx): +def _synchronize_periodic_boundary_of_scalar_collection( + input: NDArray[np.floating], periodic_idx: NDArray[np.floating] +) -> None: """ This function synchronizes the periodic boundaries of a scalar collection. @@ -76,10 +86,11 @@ class _ConstrainPeriodicBoundaries(ConstraintBase): is to synchronize periodic boundaries of ring rod. """ - def __init__(self, **kwargs): + # TODO: improve typing + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def constrain_values(self, rod, time): + def constrain_values(self, rod: SystemType, time: np.floating) -> None: _synchronize_periodic_boundary_of_vector_collection( rod.position_collection, rod.periodic_boundary_nodes_idx ) @@ -87,7 +98,7 @@ def constrain_values(self, rod, time): rod.director_collection, rod.periodic_boundary_elems_idx ) - def constrain_rates(self, rod, time): + def constrain_rates(self, rod: SystemType, time: np.floating) -> None: _synchronize_periodic_boundary_of_vector_collection( rod.velocity_collection, rod.periodic_boundary_nodes_idx ) diff --git a/elastica/boundary_conditions.py b/elastica/boundary_conditions.py index 0cf666c8..9203c958 100644 --- a/elastica/boundary_conditions.py +++ b/elastica/boundary_conditions.py @@ -1,9 +1,10 @@ __doc__ = """ Built-in boundary condition implementationss """ import warnings -from typing import Optional +from typing import Any, Optional, Tuple import numpy as np +from numpy.typing import NDArray from abc import ABC, abstractmethod @@ -34,7 +35,7 @@ class ConstraintBase(ABC): _constrained_position_idx: np.ndarray _constrained_director_idx: np.ndarray - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize boundary condition""" try: self._system = kwargs["_system"] @@ -67,7 +68,7 @@ def constrained_director_idx(self) -> Optional[np.ndarray]: return self._constrained_director_idx @abstractmethod - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: np.floating) -> None: # TODO: In the future, we can remove rod and use self.system """ Constrain values (position and/or directors) of a rod object. @@ -82,7 +83,7 @@ def constrain_values(self, system: SystemType, time: float) -> None: pass @abstractmethod - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: np.floating) -> None: # TODO: In the future, we can remove rod and use self.system """ Constrain rates (velocity and/or omega) of a rod object. @@ -103,14 +104,14 @@ class FreeBC(ConstraintBase): Boundary condition template. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: np.floating) -> None: """In FreeBC, this routine simply passes.""" pass - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: np.floating) -> None: """In FreeBC, this routine simply passes.""" pass @@ -143,7 +144,12 @@ class OneEndFixedBC(ConstraintBase): ... ) """ - def __init__(self, fixed_position, fixed_directors, **kwargs): + def __init__( + self, + fixed_position: Tuple[int, ...], + fixed_directors: Tuple[int, ...], + **kwargs: Any, + ) -> None: """ Initialization of the constraint. Any parameter passed to 'using' will be available in kwargs. @@ -159,7 +165,7 @@ def __init__(self, fixed_position, fixed_directors, **kwargs): self.fixed_position_collection = np.array(fixed_position) self.fixed_directors_collection = np.array(fixed_directors) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: np.floating) -> None: # system.position_collection[..., 0] = self.fixed_position # system.director_collection[..., 0] = self.fixed_directors self.compute_constrain_values( @@ -169,7 +175,7 @@ def constrain_values(self, system: SystemType, time: float) -> None: self.fixed_directors_collection, ) - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: np.floating) -> None: # system.velocity_collection[..., 0] = 0.0 # system.omega_collection[..., 0] = 0.0 self.compute_constrain_rates( @@ -180,11 +186,11 @@ def constrain_rates(self, system: SystemType, time: float) -> None: @staticmethod @njit(cache=True) def compute_constrain_values( - position_collection, - fixed_position_collection, - director_collection, - fixed_directors_collection, - ): + position_collection: NDArray[np.floating], + fixed_position_collection: NDArray[np.floating], + director_collection: NDArray[np.floating], + fixed_directors_collection: NDArray[np.floating], + ) -> None: """ Computes constrain values in numba njit decorator @@ -208,7 +214,10 @@ def compute_constrain_values( @staticmethod @njit(cache=True) - def compute_constrain_rates(velocity_collection, omega_collection): + def compute_constrain_rates( + velocity_collection: NDArray[np.floating], + omega_collection: NDArray[np.floating], + ) -> None: """ Compute contrain rates in numba njit decorator @@ -266,11 +275,11 @@ class GeneralConstraint(ConstraintBase): def __init__( self, - *fixed_data, - translational_constraint_selector: Optional[np.ndarray] = None, - rotational_constraint_selector: Optional[np.array] = None, - **kwargs, - ): + *fixed_data: Any, + translational_constraint_selector: Optional[NDArray[np.bool_]] = None, + rotational_constraint_selector: Optional[NDArray[np.bool_]] = None, + **kwargs: Any, + ) -> None: """ Initialization of the constraint. Any parameter passed to 'using' will be available in kwargs. @@ -331,7 +340,7 @@ def __init__( ) self.rotational_constraint_selector = rotational_constraint_selector.astype(int) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: np.floating) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_values( system.position_collection, @@ -340,7 +349,7 @@ def constrain_values(self, system: SystemType, time: float) -> None: self.translational_constraint_selector, ) - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: np.floating) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_rates( system.velocity_collection, @@ -358,7 +367,10 @@ def constrain_rates(self, system: SystemType, time: float) -> None: @staticmethod @njit(cache=True) def nb_constrain_translational_values( - position_collection, fixed_position_collection, indices, constraint_selector + position_collection: NDArray[np.floating], + fixed_position_collection: NDArray[np.floating], + indices: NDArray[np.integer], + constraint_selector: NDArray[np.integer], ) -> None: """ Computes constrain values in numba njit decorator @@ -393,7 +405,9 @@ def nb_constrain_translational_values( @staticmethod @njit(cache=True) def nb_constrain_translational_rates( - velocity_collection, indices, constraint_selector + velocity_collection: NDArray[np.floating], + indices: NDArray[np.integer], + constraint_selector: NDArray[np.integer], ) -> None: """ Compute constrain rates in numba njit decorator @@ -422,7 +436,10 @@ def nb_constrain_translational_rates( @staticmethod @njit(cache=True) def nb_constrain_rotational_rates( - director_collection, omega_collection, indices, constraint_selector + director_collection: NDArray[np.floating], + omega_collection: NDArray[np.floating], + indices: NDArray[np.integer], + constraint_selector: NDArray[np.integer], ) -> None: """ Compute constrain rates in numba njit decorator @@ -489,7 +506,7 @@ class FixedConstraint(GeneralConstraint): GeneralConstraint: Generalized constraint with configurable DOF. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """ Initialization of the constraint. Any parameter passed to 'using' will be available in kwargs. @@ -508,7 +525,7 @@ def __init__(self, *args, **kwargs): **kwargs, ) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: np.floating) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_values( system.position_collection, @@ -522,7 +539,7 @@ def constrain_values(self, system: SystemType, time: float) -> None: self.constrained_director_idx, ) - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: np.floating) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_rates( system.velocity_collection, @@ -537,7 +554,9 @@ def constrain_rates(self, system: SystemType, time: float) -> None: @staticmethod @njit(cache=True) def nb_constraint_rotational_values( - director_collection, fixed_director_collection, indices + director_collection: NDArray[np.floating], + fixed_director_collection: NDArray[np.floating], + indices: NDArray[np.integer], ) -> None: """ Computes constrain values in numba njit decorator @@ -558,7 +577,9 @@ def nb_constraint_rotational_values( @staticmethod @njit(cache=True) def nb_constrain_translational_values( - position_collection, fixed_position_collection, indices + position_collection: NDArray[np.floating], + fixed_position_collection: NDArray[np.floating], + indices: NDArray[np.integer], ) -> None: """ Computes constrain values in numba njit decorator @@ -578,7 +599,9 @@ def nb_constrain_translational_values( @staticmethod @njit(cache=True) - def nb_constrain_translational_rates(velocity_collection, indices) -> None: + def nb_constrain_translational_rates( + velocity_collection: NDArray[np.floating], indices: NDArray[np.integer] + ) -> None: """ Compute constrain rates in numba njit decorator Parameters @@ -598,7 +621,9 @@ def nb_constrain_translational_rates(velocity_collection, indices) -> None: @staticmethod @njit(cache=True) - def nb_constrain_rotational_rates(omega_collection, indices) -> None: + def nb_constrain_rotational_rates( + omega_collection: NDArray[np.floating], indices: NDArray[np.integer] + ) -> None: """ Compute constrain rates in numba njit decorator Parameters @@ -654,15 +679,15 @@ class HelicalBucklingBC(ConstraintBase): def __init__( self, - position_start: np.ndarray, - position_end: np.ndarray, - director_start: np.ndarray, - director_end: np.ndarray, - twisting_time: float, - slack: float, - number_of_rotations: float, - **kwargs, - ): + position_start: NDArray[np.floating], + position_end: NDArray[np.floating], + director_start: NDArray[np.floating], + director_end: NDArray[np.floating], + twisting_time: np.floating, + slack: np.floating, + number_of_rotations: np.floating, + **kwargs: Any, + ) -> None: """ Helical Buckling initializer @@ -718,7 +743,7 @@ def __init__( @ director_end ) # rotation_matrix wants vectors 3,1 - def constrain_values(self, rod: RodType, time: float) -> None: + def constrain_values(self, rod: RodType, time: np.floating) -> None: if time > self.twisting_time: rod.position_collection[..., 0] = self.final_start_position rod.position_collection[..., -1] = self.final_end_position @@ -726,7 +751,7 @@ def constrain_values(self, rod: RodType, time: float) -> None: rod.director_collection[..., 0] = self.final_start_directors rod.director_collection[..., -1] = self.final_end_directors - def constrain_rates(self, rod: RodType, time: float) -> None: + def constrain_rates(self, rod: RodType, time: np.floating) -> None: if time > self.twisting_time: rod.velocity_collection[..., 0] = 0.0 rod.omega_collection[..., 0] = 0.0 diff --git a/elastica/callback_functions.py b/elastica/callback_functions.py index ab865dfd..cfd28068 100644 --- a/elastica/callback_functions.py +++ b/elastica/callback_functions.py @@ -4,9 +4,12 @@ import sys import numpy as np import logging +from typing import Any, Optional from collections import defaultdict +from elastica.typing import RodType, SystemType + class CallBackBaseClass: """ @@ -19,13 +22,13 @@ class CallBackBaseClass: """ - def __init__(self): + def __init__(self) -> None: """ CallBackBaseClass does not need any input parameters. """ pass - def make_callback(self, system, time, current_step: int): + def make_callback(self, syste: RodType, time: np.floating, current_step: int) -> None: """ This method is called every time step. Users can define which parameters are called back and recorded. Also users @@ -59,7 +62,7 @@ class MyCallBack(CallBackBaseClass): Collected callback data is saved in this dictionary. """ - def __init__(self, step_skip: int, callback_params): + def __init__(self, step_skip: int, callback_params: dict) -> None: """ Parameters @@ -73,7 +76,7 @@ def __init__(self, step_skip: int, callback_params): self.sample_every = step_skip self.callback_params = callback_params - def make_callback(self, system, time, current_step: int): + def make_callback(self, system: SystemType, time: np.floating, current_step: int) -> None: if current_step % self.sample_every == 0: @@ -116,8 +119,8 @@ def __init__( directory: str, method: str, initial_file_count: int = 0, - file_save_interval: int = 1e8, - ): + file_save_interval: int = 100_000_000, + ) -> None: """ Parameters ---------- @@ -189,7 +192,7 @@ def __init__( self._pickle = pickle self._ext = "pkl" - def make_callback(self, system, time, current_step: int): + def make_callback(self, system: SystemType, time: np.floating, current_step: int) -> None: """ Parameters @@ -224,7 +227,7 @@ def make_callback(self, system, time, current_step: int): ): self._dump() - def _dump(self, **kwargs): + def _dump(self, **kwargs: Any) -> None: """ Dump dictionary buffer (self.buffer) to a file and clear the buffer. @@ -247,7 +250,7 @@ def _dump(self, **kwargs): self.buffer_size = 0 self.buffer.clear() - def get_last_saved_path(self) -> str: + def get_last_saved_path(self) -> Optional[str]: """ Return last saved file path. If no file has been saved, return None @@ -257,14 +260,14 @@ def get_last_saved_path(self) -> str: else: return self.save_path.format(self.file_count - 1, self._ext) - def close(self): + def close(self) -> None: """ Save residual buffer """ if self.buffer_size: self._dump() - def clear(self): + def clear(self) -> None: """ Alias to `close` """ diff --git a/elastica/contact_forces.py b/elastica/contact_forces.py index 8f9b0ab5..9e9f5a2d 100644 --- a/elastica/contact_forces.py +++ b/elastica/contact_forces.py @@ -1,5 +1,6 @@ __doc__ = """ Numba implementation module containing contact between rods and rigid bodies and other rods rigid bodies or surfaces.""" +from typing import Optional from elastica.typing import RodType, SystemType, AllowedContactType from elastica.rod import RodBase from elastica.rigidbody import Cylinder, Sphere @@ -19,6 +20,7 @@ _calculate_contact_forces_cylinder_plane, ) import numpy as np +from numpy.typing import NDArray class NoContact: @@ -32,7 +34,7 @@ class NoContact: """ - def __init__(self): + def __init__(self) -> None: """ NoContact class does not need any input parameters. """ @@ -69,7 +71,7 @@ def apply_contact( self, system_one: SystemType, system_two: AllowedContactType, - ) -> None: + ) -> Optional[tuple[NDArray[np.floating], NDArray[np.intp]]]: """ Apply contact forces and torques between SystemType object and AllowedContactType object. @@ -101,7 +103,7 @@ class RodRodContact(NoContact): """ - def __init__(self, k: float, nu: float): + def __init__(self, k: np.floating, nu: np.floating) -> None: """ Parameters ---------- @@ -225,11 +227,11 @@ class RodCylinderContact(NoContact): def __init__( self, - k: float, - nu: float, - velocity_damping_coefficient=0.0, - friction_coefficient=0.0, - ): + k: np.floating, + nu: np.floating, + velocity_damping_coefficient: np.floating = 0.0, + friction_coefficient: np.floating = 0.0, + ) -> None: """ Parameters @@ -338,7 +340,7 @@ class RodSelfContact(NoContact): """ - def __init__(self, k: float, nu: float): + def __init__(self, k: np.floating, nu: np.floating) -> None: """ Parameters @@ -435,11 +437,11 @@ class RodSphereContact(NoContact): def __init__( self, - k: float, - nu: float, - velocity_damping_coefficient=0.0, - friction_coefficient=0.0, - ): + k: np.floating, + nu: np.floating, + velocity_damping_coefficient: np.floating = 0.0, + friction_coefficient: np.floating = 0.0, + ) -> None: """ Parameters ---------- @@ -560,9 +562,9 @@ class RodPlaneContact(NoContact): def __init__( self, - k: float, - nu: float, - ): + k: np.floating, + nu: np.floating, + ) -> None: """ Parameters ---------- @@ -652,12 +654,12 @@ class RodPlaneContactWithAnisotropicFriction(NoContact): def __init__( self, - k: float, - nu: float, - slip_velocity_tol: float, - static_mu_array: np.ndarray, - kinetic_mu_array: np.ndarray, - ): + k: np.floating, + nu: np.floating, + slip_velocity_tol: np.floating, + static_mu_array: NDArray[np.floating], + kinetic_mu_array: NDArray[np.floating], + ) -> None: """ Parameters ---------- @@ -776,9 +778,9 @@ class CylinderPlaneContact(NoContact): def __init__( self, - k: float, - nu: float, - ): + k: np.floating, + nu: np.floating, + ) -> None: """ Parameters ---------- @@ -818,7 +820,9 @@ def _check_systems_validity( ) ) - def apply_contact(self, system_one: Cylinder, system_two: SystemType): + def apply_contact( + self, system_one: Cylinder, system_two: SystemType + ) -> tuple[NDArray[np.floating], NDArray[np.intp]]: """ This function computes the plane force response on the cylinder, in the case of contact. Contact model given in Eqn 4.8 Gazzola et. al. RSoS 2018 paper diff --git a/elastica/contact_utils.py b/elastica/contact_utils.py index 71584d2b..657b6dd0 100644 --- a/elastica/contact_utils.py +++ b/elastica/contact_utils.py @@ -3,37 +3,47 @@ from math import sqrt import numba import numpy as np +from numpy.typing import NDArray from elastica._linalg import ( _batch_norm, ) +from typing import Literal, Sequence, TypeVar @numba.njit(cache=True) -def _dot_product(a, b): - sum = 0.0 +def _dot_product(a: Sequence[np.floating], b: Sequence[np.floating]) -> np.floating: + sum: np.floating = 0.0 for i in range(3): sum += a[i] * b[i] return sum @numba.njit(cache=True) -def _norm(a): +def _norm(a: Sequence[np.floating]) -> float: return sqrt(_dot_product(a, a)) +_SupportsCompareT = TypeVar("_SupportsCompareT") + + @numba.njit(cache=True) -def _clip(x, low, high): +def _clip(x: np.floating, low: np.floating, high: np.floating) -> np.floating: return max(low, min(x, high)) # Can this be made more efficient than 2 comp, 1 or? @numba.njit(cache=True) -def _out_of_bounds(x, low, high): +def _out_of_bounds(x: np.floating, low: np.floating, high: np.floating) -> bool: return (x < low) or (x > high) @numba.njit(cache=True) -def _find_min_dist(x1, e1, x2, e2): +def _find_min_dist( + x1: NDArray[np.floating], + e1: NDArray[np.floating], + x2: NDArray[np.floating], + e2: NDArray[np.floating], +) -> tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]: e1e1 = _dot_product(e1, e1) e1e2 = _dot_product(e1, e2) e2e2 = _dot_product(e2, e2) @@ -99,7 +109,9 @@ def _find_min_dist(x1, e1, x2, e2): @numba.njit(cache=True) -def _aabbs_not_intersecting(aabb_one, aabb_two): +def _aabbs_not_intersecting( + aabb_one: NDArray[np.floating], aabb_two: NDArray[np.floating] +) -> Literal[1, 0]: """Returns true if not intersecting else false""" if (aabb_one[0, 1] < aabb_two[0, 0]) | (aabb_one[0, 0] > aabb_two[0, 1]): return 1 @@ -113,14 +125,14 @@ def _aabbs_not_intersecting(aabb_one, aabb_two): @numba.njit(cache=True) def _prune_using_aabbs_rod_cylinder( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - cylinder_position, - cylinder_director, - cylinder_radius, - cylinder_length, -): + rod_one_position_collection: NDArray[np.floating], + rod_one_radius_collection: NDArray[np.floating], + rod_one_length_collection: NDArray[np.floating], + cylinder_position: NDArray[np.floating], + cylinder_director: NDArray[np.floating], + cylinder_radius: NDArray[np.floating], + cylinder_length: NDArray[np.floating], +) -> Literal[1, 0]: max_possible_dimension = np.zeros((3,)) aabb_rod = np.empty((3, 2)) aabb_cylinder = np.empty((3, 2)) @@ -155,13 +167,13 @@ def _prune_using_aabbs_rod_cylinder( @numba.njit(cache=True) def _prune_using_aabbs_rod_rod( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - rod_two_position_collection, - rod_two_radius_collection, - rod_two_length_collection, -): + rod_one_position_collection: NDArray[np.floating], + rod_one_radius_collection: NDArray[np.floating], + rod_one_length_collection: NDArray[np.floating], + rod_two_position_collection: NDArray[np.floating], + rod_two_radius_collection: NDArray[np.floating], + rod_two_length_collection: NDArray[np.floating], +) -> Literal[1, 0]: max_possible_dimension = np.zeros((3,)) aabb_rod_one = np.empty((3, 2)) aabb_rod_two = np.empty((3, 2)) @@ -193,13 +205,13 @@ def _prune_using_aabbs_rod_rod( @numba.njit(cache=True) def _prune_using_aabbs_rod_sphere( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - sphere_position, - sphere_director, - sphere_radius, -): + rod_one_position_collection: NDArray[np.floating], + rod_one_radius_collection: NDArray[np.floating], + rod_one_length_collection: NDArray[np.floating], + sphere_position: NDArray[np.floating], + sphere_director: NDArray[np.floating], + sphere_radius: NDArray[np.floating], +) -> Literal[1, 0]: max_possible_dimension = np.zeros((3,)) aabb_rod = np.empty((3, 2)) aabb_sphere = np.empty((3, 2)) @@ -231,7 +243,9 @@ def _prune_using_aabbs_rod_sphere( @numba.njit(cache=True) -def _find_slipping_elements(velocity_slip, velocity_threshold): +def _find_slipping_elements( + velocity_slip: NDArray[np.floating], velocity_threshold: np.floating +) -> NDArray[np.floating]: """ This function takes the velocity of elements and checks if they are larger than the threshold velocity. If the velocity of elements is larger than threshold velocity, that means those elements are slipping. @@ -272,7 +286,7 @@ def _find_slipping_elements(velocity_slip, velocity_threshold): @numba.njit(cache=True) -def _node_to_element_mass_or_force(input): +def _node_to_element_mass_or_force(input: NDArray[np.floating]) -> NDArray[np.floating]: """ This function converts the mass/forces on rod nodes to elements, where special treatment is necessary at the ends. @@ -310,7 +324,10 @@ def _node_to_element_mass_or_force(input): @numba.njit(cache=True) -def _elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame): +def _elements_to_nodes_inplace( + vector_in_element_frame: NDArray[np.floating], + vector_in_node_frame: NDArray[np.floating], +) -> None: """ Updating nodal forces using the forces computed on elements Parameters @@ -333,7 +350,9 @@ def _elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame): @numba.njit(cache=True) -def _node_to_element_position(node_position_collection): +def _node_to_element_position( + node_position_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ This function computes the position of the elements from the nodal values. @@ -379,7 +398,9 @@ def _node_to_element_position(node_position_collection): @numba.njit(cache=True) -def _node_to_element_velocity(mass, node_velocity_collection): +def _node_to_element_velocity( + mass: NDArray[np.floating], node_velocity_collection: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function computes the velocity of the elements from the nodal values. Uses the velocity of center of mass diff --git a/elastica/dissipation.py b/elastica/dissipation.py index f3629fe7..b8ddc95e 100644 --- a/elastica/dissipation.py +++ b/elastica/dissipation.py @@ -5,12 +5,14 @@ """ from abc import ABC, abstractmethod +from typing import Any from elastica.typing import RodType, SystemType from numba import njit import numpy as np +from numpy.typing import NDArray class DamperBase(ABC): @@ -29,7 +31,8 @@ class DamperBase(ABC): _system: SystemType - def __init__(self, *args, **kwargs): + # TODO typing can be made better + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize damping module""" try: self._system = kwargs["_system"] @@ -40,7 +43,7 @@ def __init__(self, *args, **kwargs): ) @property - def system(self): # -> SystemType: (Return type is not parsed with sphinx book.) + def system(self) -> SystemType: """ get system (rod or rigid body) reference @@ -52,7 +55,7 @@ def system(self): # -> SystemType: (Return type is not parsed with sphinx book. return self._system @abstractmethod - def dampen_rates(self, system: SystemType, time: float): + def dampen_rates(self, system: SystemType, time: np.floating) -> None: # TODO: In the future, we can remove rod and use self.system """ Dampen rates (velocity and/or omega) of a rod object. @@ -113,7 +116,9 @@ class AnalyticalLinearDamper(DamperBase): Damping coefficient acting on rotational velocity. """ - def __init__(self, damping_constant, time_step, **kwargs): + def __init__( + self, damping_constant: np.floating, time_step: np.floating, **kwargs: Any + ) -> None: """ Analytical linear damper initializer @@ -143,7 +148,7 @@ def __init__(self, damping_constant, time_step, **kwargs): * np.diagonal(self._system.inv_mass_second_moment_of_inertia).T ) - def dampen_rates(self, rod: RodType, time: float): + def dampen_rates(self, rod: RodType, time: np.floating) -> None: rod.velocity_collection[:] = ( rod.velocity_collection * self.translational_damping_coefficient ) @@ -202,7 +207,7 @@ class LaplaceDissipationFilter(DamperBase): Filter term that modifies rod rotational velocity. """ - def __init__(self, filter_order: int, **kwargs): + def __init__(self, filter_order: int, **kwargs: Any) -> None: """ Filter damper initializer @@ -232,7 +237,7 @@ def __init__(self, filter_order: int, **kwargs): self.omega_filter_term = np.zeros_like(self._system.omega_collection) self.filter_function = _filter_function_periodic_condition - def dampen_rates(self, rod: RodType, time: float) -> None: + def dampen_rates(self, rod: RodType, time: np.floating) -> None: self.filter_function( rod.velocity_collection, @@ -245,12 +250,12 @@ def dampen_rates(self, rod: RodType, time: float) -> None: @njit(cache=True) def _filter_function_periodic_condition_ring_rod( - velocity_collection, - velocity_filter_term, - omega_collection, - omega_filter_term, - filter_order, -): + velocity_collection: NDArray[np.floating], + velocity_filter_term: NDArray[np.floating], + omega_collection: NDArray[np.floating], + omega_filter_term: NDArray[np.floating], + filter_order: int, +) -> None: blocksize = velocity_filter_term.shape[1] # Transfer velocity to an array which has periodic boundaries and synchornize boundaries @@ -283,12 +288,12 @@ def _filter_function_periodic_condition_ring_rod( @njit(cache=True) def _filter_function_periodic_condition( - velocity_collection, - velocity_filter_term, - omega_collection, - omega_filter_term, - filter_order, -): + velocity_collection: NDArray[np.floating], + velocity_filter_term: NDArray[np.floating], + omega_collection: NDArray[np.floating], + omega_filter_term: NDArray[np.floating], + filter_order: int, +) -> None: nb_filter_rate( rate_collection=velocity_collection, filter_term=velocity_filter_term, @@ -303,7 +308,9 @@ def _filter_function_periodic_condition( @njit(cache=True) def nb_filter_rate( - rate_collection: np.ndarray, filter_term: np.ndarray, filter_order: int + rate_collection: NDArray[np.floating], + filter_term: NDArray[np.floating], + filter_order: int, ) -> None: """ Filters the rod rates (velocities) in numba njit decorator diff --git a/elastica/external_forces.py b/elastica/external_forces.py index cb9c61e6..ccf25764 100644 --- a/elastica/external_forces.py +++ b/elastica/external_forces.py @@ -3,6 +3,8 @@ import numpy as np +from numpy.typing import NDArray + from elastica._linalg import _batch_matvec from elastica.typing import SystemType, RodType from elastica.utils import _bspline @@ -22,13 +24,13 @@ class NoForces: """ - def __init__(self): + def __init__(self) -> None: """ NoForces class does not need any input parameters. """ pass - def apply_forces(self, system: SystemType, time: np.float64 = 0.0): + def apply_forces(self, system: SystemType, time: np.floating = 0.0) -> None: """Apply forces to a rod-like object. In NoForces class, this routine simply passes. @@ -43,7 +45,7 @@ def apply_forces(self, system: SystemType, time: np.float64 = 0.0): """ pass - def apply_torques(self, system: SystemType, time: np.float64 = 0.0): + def apply_torques(self, system: SystemType, time: np.floating = 0.0) -> None: """Apply torques to a rod-like object. In NoForces class, this routine simply passes. @@ -70,7 +72,9 @@ class GravityForces(NoForces): """ - def __init__(self, acc_gravity=np.array([0.0, -9.80665, 0.0])): + def __init__( + self, acc_gravity: NDArray[np.floating] = np.array([0.0, -9.80665, 0.0]) + ) -> None: """ Parameters @@ -82,14 +86,18 @@ def __init__(self, acc_gravity=np.array([0.0, -9.80665, 0.0])): super(GravityForces, self).__init__() self.acc_gravity = acc_gravity - def apply_forces(self, system: SystemType, time=0.0): + def apply_forces(self, system: SystemType, time: np.floating = 0.0) -> None: self.compute_gravity_forces( self.acc_gravity, system.mass, system.external_forces ) @staticmethod @njit(cache=True) - def compute_gravity_forces(acc_gravity, mass, external_forces): + def compute_gravity_forces( + acc_gravity: NDArray[np.floating], + mass: NDArray[np.floating], + external_forces: NDArray[np.floating], + ) -> None: """ This function add gravitational forces on the nodes. We are using njit decorated function to increase the speed. @@ -122,7 +130,12 @@ class EndpointForces(NoForces): """ - def __init__(self, start_force, end_force, ramp_up_time): + def __init__( + self, + start_force: NDArray[np.floating], + end_force: NDArray[np.floating], + ramp_up_time: np.floating, + ) -> None: """ Parameters @@ -143,7 +156,7 @@ def __init__(self, start_force, end_force, ramp_up_time): assert ramp_up_time > 0.0 self.ramp_up_time = ramp_up_time - def apply_forces(self, system: SystemType, time=0.0): + def apply_forces(self, system: SystemType, time: np.floating = 0.0) -> None: self.compute_end_point_forces( system.external_forces, self.start_force, @@ -155,8 +168,12 @@ def apply_forces(self, system: SystemType, time=0.0): @staticmethod @njit(cache=True) def compute_end_point_forces( - external_forces, start_force, end_force, time, ramp_up_time - ): + external_forces: NDArray[np.floating], + start_force: NDArray[np.floating], + end_force: NDArray[np.floating], + time: np.floating, + ramp_up_time: np.floating, + ) -> None: """ Compute end point forces that are applied on the rod using numba njit decorator. @@ -174,7 +191,7 @@ def compute_end_point_forces( Applied forces are ramped up until ramp up time. """ - factor = min(1.0, time / ramp_up_time) + factor: np.floating = min(1.0, time / ramp_up_time) external_forces[..., 0] += start_force * factor external_forces[..., -1] += end_force * factor @@ -190,7 +207,11 @@ class UniformTorques(NoForces): """ - def __init__(self, torque, direction=np.array([0.0, 0.0, 0.0])): + def __init__( + self, + torque: np.floating, + direction: NDArray[np.floating] = np.array([0.0, 0.0, 0.0]), + ) -> None: """ Parameters @@ -204,7 +225,7 @@ def __init__(self, torque, direction=np.array([0.0, 0.0, 0.0])): super(UniformTorques, self).__init__() self.torque = torque * direction - def apply_torques(self, system: SystemType, time: np.float64 = 0.0): + def apply_torques(self, system: SystemType, time: np.floating = 0.0) -> None: n_elems = system.n_elems torque_on_one_element = ( _batch_product_i_k_to_ik(self.torque, np.ones((n_elems))) / n_elems @@ -224,7 +245,11 @@ class UniformForces(NoForces): 2D (dim, 1) array containing data with 'float' type. Total force applied to a rod-like object. """ - def __init__(self, force, direction=np.array([0.0, 0.0, 0.0])): + def __init__( + self, + force: np.floating, + direction: NDArray[np.floating] = np.array([0.0, 0.0, 0.0]), + ) -> None: """ Parameters @@ -238,7 +263,7 @@ def __init__(self, force, direction=np.array([0.0, 0.0, 0.0])): super(UniformForces, self).__init__() self.force = (force * direction).reshape(3, 1) - def apply_forces(self, rod: RodType, time: np.float64 = 0.0): + def apply_forces(self, rod: RodType, time: np.floating = 0.0) -> None: force_on_one_element = self.force / rod.n_elems rod.external_forces += force_on_one_element @@ -275,16 +300,16 @@ class MuscleTorques(NoForces): def __init__( self, - base_length, - b_coeff, - period, - wave_number, - phase_shift, - direction, - rest_lengths, - ramp_up_time, - with_spline=False, - ): + base_length: np.floating, + b_coeff: NDArray[np.floating], + period: np.floating, + wave_number: np.floating, + phase_shift: np.floating, + direction: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + ramp_up_time: np.floating, + with_spline: bool = False, + ) -> None: """ Parameters @@ -335,7 +360,7 @@ def __init__( else: self.my_spline = np.full_like(self.s, fill_value=1.0) - def apply_torques(self, rod: RodType, time: np.float64 = 0.0): + def apply_torques(self, rod: RodType, time: np.floating = 0.0) -> None: self.compute_muscle_torques( time, self.my_spline, @@ -352,19 +377,19 @@ def apply_torques(self, rod: RodType, time: np.float64 = 0.0): @staticmethod @njit(cache=True) def compute_muscle_torques( - time, - my_spline, - s, - angular_frequency, - wave_number, - phase_shift, - ramp_up_time, - direction, - director_collection, - external_torques, - ): + time: np.floating, + my_spline: NDArray[np.floating], + s: np.floating, + angular_frequency: np.floating, + wave_number: np.floating, + phase_shift: np.floating, + ramp_up_time: np.floating, + direction: NDArray[np.floating], + director_collection: NDArray[np.floating], + external_torques: NDArray[np.floating], + ) -> None: # Ramp up the muscle torque - factor = min(1.0, time / ramp_up_time) + factor: np.floating = min(1.0, time / ramp_up_time) # From the node 1 to node nelem-1 # Magnitude of the torque. Am = beta(s) * sin(2pi*t/T + 2pi*s/lambda + phi) # There is an inconsistency with paper and Elastica cpp implementation. In paper sign in @@ -388,7 +413,10 @@ def compute_muscle_torques( @njit(cache=True) -def inplace_addition(external_force_or_torque, force_or_torque): +def inplace_addition( + external_force_or_torque: NDArray[np.floating], + force_or_torque: NDArray[np.floating], +) -> None: """ This function does inplace addition. First argument `external_force_or_torque` is the system.external_forces @@ -411,7 +439,10 @@ def inplace_addition(external_force_or_torque, force_or_torque): @njit(cache=True) -def inplace_substraction(external_force_or_torque, force_or_torque): +def inplace_substraction( + external_force_or_torque: NDArray[np.floating], + force_or_torque: NDArray[np.floating], +) -> None: """ This function does inplace substraction. First argument `external_force_or_torque` is the system.external_forces @@ -460,12 +491,12 @@ class EndpointForcesSinusoidal(NoForces): def __init__( self, - start_force_mag, - end_force_mag, - ramp_up_time=0.0, - tangent_direction=np.array([0, 0, 1]), - normal_direction=np.array([0, 1, 0]), - ): + start_force_mag: np.floating, + end_force_mag: np.floating, + ramp_up_time: np.floating = 0.0, + tangent_direction: NDArray[np.floating] = np.array([0, 0, 1]), + normal_direction: NDArray[np.floating] = np.array([0, 1, 0]), + ) -> None: """ Parameters @@ -495,7 +526,7 @@ def __init__( assert ramp_up_time >= 0.0 self.ramp_up_time = ramp_up_time - def apply_forces(self, system: SystemType, time=0.0): + def apply_forces(self, system: SystemType, time: np.floating = 0.0) -> None: if time < self.ramp_up_time: # When time smaller than ramp up time apply the force in normal direction diff --git a/elastica/interaction.py b/elastica/interaction.py index 95d4602e..d695e615 100644 --- a/elastica/interaction.py +++ b/elastica/interaction.py @@ -1,6 +1,7 @@ __doc__ = """ Numba implementation module containing interactions between a rod and its environment.""" +from typing import Any, NoReturn import numpy as np from elastica.external_forces import NoForces from numba import njit @@ -14,8 +15,12 @@ _calculate_contact_forces_cylinder_plane, ) +from numpy.typing import NDArray -def find_slipping_elements(velocity_slip, velocity_threshold): +from elastica.typing import SystemType + + +def find_slipping_elements(velocity_slip: Any, velocity_threshold: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._find_slipping_elements()\n" @@ -23,7 +28,7 @@ def find_slipping_elements(velocity_slip, velocity_threshold): ) -def node_to_element_mass_or_force(input): +def node_to_element_mass_or_force(input: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._node_to_element_mass_or_force()\n" @@ -31,7 +36,7 @@ def node_to_element_mass_or_force(input): ) -def nodes_to_elements(input): +def nodes_to_elements(input: Any) -> NoReturn: # Remove the function beyond v0.4.0 raise NotImplementedError( "This function is removed in v0.3.1. Please use\n" @@ -41,7 +46,9 @@ def nodes_to_elements(input): @njit(cache=True) -def elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame): +def elements_to_nodes_inplace( + vector_in_element_frame: Any, vector_in_node_frame: Any +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._elements_to_nodes_inplace()\n" @@ -74,7 +81,13 @@ class InteractionPlane: """ - def __init__(self, k, nu, plane_origin, plane_normal): + def __init__( + self, + k: np.floating, + nu: np.floating, + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + ) -> None: """ Parameters @@ -96,7 +109,9 @@ def __init__(self, k, nu, plane_origin, plane_normal): self.plane_normal = plane_normal.reshape(3) self.surface_tol = 1e-4 - def apply_normal_force(self, system): + def apply_normal_force( + self, system: SystemType + ) -> tuple[NDArray[np.floating], NDArray[np.intp]]: """ In the case of contact with the plane, this function computes the plane reaction force on the element. @@ -130,18 +145,18 @@ def apply_normal_force(self, system): def apply_normal_force_numba( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - radius, - mass, - position_collection, - velocity_collection, - internal_forces, - external_forces, -): + plane_origin: Any, + plane_normal: Any, + surface_tol: Any, + k: Any, + nu: Any, + radius: Any, + mass: Any, + position_collection: Any, + velocity_collection: Any, + internal_forces: Any, + external_forces: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For rod plane contact please use: \n" "elastica._contact_functions._calculate_contact_forces_rod_plane() \n" @@ -186,14 +201,14 @@ class AnisotropicFrictionalPlane(NoForces, InteractionPlane): def __init__( self, - k, - nu, - plane_origin, - plane_normal, - slip_velocity_tol, - static_mu_array, - kinetic_mu_array, - ): + k: np.floating, + nu: np.floating, + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + slip_velocity_tol: np.floating, + static_mu_array: NDArray[np.floating], + kinetic_mu_array: NDArray[np.floating], + ) -> None: """ Parameters @@ -232,7 +247,7 @@ def __init__( # kinetic and static friction should separate functions # for now putting them together to figure out common variables - def apply_forces(self, system, time=0.0): + def apply_forces(self, system: SystemType, time: np.floating = 0.0) -> None: """ Call numba implementation to apply friction forces Parameters @@ -269,30 +284,30 @@ def apply_forces(self, system, time=0.0): def anisotropic_friction( - plane_origin, - plane_normal, - surface_tol, - slip_velocity_tol, - k, - nu, - kinetic_mu_forward, - kinetic_mu_backward, - kinetic_mu_sideways, - static_mu_forward, - static_mu_backward, - static_mu_sideways, - radius, - mass, - tangents, - position_collection, - director_collection, - velocity_collection, - omega_collection, - internal_forces, - external_forces, - internal_torques, - external_torques, -): + plane_origin: Any, + plane_normal: Any, + surface_tol: Any, + slip_velocity_tol: Any, + k: Any, + nu: Any, + kinetic_mu_forward: Any, + kinetic_mu_backward: Any, + kinetic_mu_sideways: Any, + static_mu_forward: Any, + static_mu_backward: Any, + static_mu_sideways: Any, + radius: Any, + mass: Any, + tangents: Any, + position_collection: Any, + director_collection: Any, + velocity_collection: Any, + omega_collection: Any, + internal_forces: Any, + external_forces: Any, + internal_torques: Any, + external_torques: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For anisotropic_friction please use: \n" "elastica._contact_functions._calculate_contact_forces_rod_plane_with_anisotropic_friction() \n" @@ -302,7 +317,7 @@ def anisotropic_friction( # Slender body module @njit(cache=True) -def sum_over_elements(input): +def sum_over_elements(input: NDArray[np.floating]) -> np.floating: """ This function sums all elements of the input array. Using a Numba njit decorator shows better performance @@ -334,14 +349,14 @@ def sum_over_elements(input): This version: 513 ns ± 24.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) """ - output = 0.0 + output: np.floating = 0.0 for i in range(input.shape[0]): output += input[i] return output -def node_to_element_position(node_position_collection): +def node_to_element_position(node_position_collection: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For node-to-element_position() interpolation please use: \n" "elastica.contact_utils._node_to_element_position() for rod position \n" @@ -349,7 +364,7 @@ def node_to_element_position(node_position_collection): ) -def node_to_element_velocity(mass, node_velocity_collection): +def node_to_element_velocity(mass: Any, node_velocity_collection: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For node-to-element_velocity() interpolation please use: \n" "elastica.contact_utils._node_to_element_velocity() for rod velocity. \n" @@ -357,7 +372,7 @@ def node_to_element_velocity(mass, node_velocity_collection): ) -def node_to_element_pos_or_vel(vector_in_node_frame): +def node_to_element_pos_or_vel(vector_in_node_frame: Any) -> NoReturn: # Remove the function beyond v0.4.0 raise NotImplementedError( "This function is removed in v0.3.0. For node-to-element interpolation please use: \n" @@ -369,8 +384,13 @@ def node_to_element_pos_or_vel(vector_in_node_frame): @njit(cache=True) def slender_body_forces( - tangents, velocity_collection, dynamic_viscosity, lengths, radius, mass -): + tangents: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + dynamic_viscosity: np.floating, + lengths: NDArray[np.floating], + radius: NDArray[np.floating], + mass: NDArray[np.floating], +) -> NDArray[np.floating]: r""" This function computes hydrodynamic forces on a body using slender body theory. The below implementation is from Eq. 4.13 in Gazzola et al. RSoS. (2018). @@ -481,7 +501,7 @@ class SlenderBodyTheory(NoForces): """ - def __init__(self, dynamic_viscosity): + def __init__(self, dynamic_viscosity: np.floating) -> None: """ Parameters @@ -492,7 +512,7 @@ def __init__(self, dynamic_viscosity): super(SlenderBodyTheory, self).__init__() self.dynamic_viscosity = dynamic_viscosity - def apply_forces(self, system, time=0.0): + def apply_forces(self, system: SystemType, time: np.floating = 0.0) -> None: """ This function applies hydrodynamic forces on body using the slender body theory given in @@ -518,14 +538,22 @@ def apply_forces(self, system, time=0.0): # base class for interaction # only applies normal force no friction class InteractionPlaneRigidBody: - def __init__(self, k, nu, plane_origin, plane_normal): + def __init__( + self, + k: np.floating, + nu: np.floating, + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + ) -> None: self.k = k self.nu = nu self.plane_origin = plane_origin.reshape(3, 1) self.plane_normal = plane_normal.reshape(3) self.surface_tol = 1e-4 - def apply_normal_force(self, system): + def apply_normal_force( + self, system: SystemType + ) -> tuple[NDArray[np.floating], NDArray[np.intp]]: """ This function computes the plane force response on the rigid body, in the case of contact. Contact model given in Eqn 4.8 Gazzola et. al. RSoS 2018 paper @@ -553,16 +581,16 @@ def apply_normal_force(self, system): @njit(cache=True) def apply_normal_force_numba_rigid_body( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - length, - position_collection, - velocity_collection, - external_forces, -): + plane_origin: Any, + plane_normal: Any, + surface_tol: Any, + k: Any, + nu: Any, + length: Any, + position_collection: Any, + velocity_collection: Any, + external_forces: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For cylinder plane contact please use: \n" diff --git a/elastica/joint.py b/elastica/joint.py index e37d6f05..0547a1bf 100644 --- a/elastica/joint.py +++ b/elastica/joint.py @@ -1,8 +1,12 @@ __doc__ = """ Module containing joint classes to connect multiple rods together. """ + +from typing import Any, NoReturn, Optional + from elastica._rotations import _inv_rotate from elastica.typing import SystemType, RodType import numpy as np import logging +from numpy.typing import NDArray class FreeJoint: @@ -27,7 +31,7 @@ class FreeJoint: # pass the k and nu for the forces # also the necessary rods for the joint # indices should be 0 or -1, we will provide wrappers for users later - def __init__(self, k, nu): + def __init__(self, k: np.floating, nu: np.floating) -> None: """ Parameters @@ -42,8 +46,12 @@ def __init__(self, k, nu): self.nu = nu def apply_forces( - self, system_one: SystemType, index_one, system_two: SystemType, index_two - ): + self, + system_one: SystemType, + index_one: int, + system_two: SystemType, + index_two: int, + ) -> None: """ Apply joint force to the connected rod objects. @@ -81,8 +89,12 @@ def apply_forces( return def apply_torques( - self, system_one: SystemType, index_one, system_two: SystemType, index_two - ): + self, + system_one: SystemType, + index_one: int, + system_two: SystemType, + index_two: int, + ) -> None: """ Apply restoring joint torques to the connected rod objects. @@ -127,7 +139,9 @@ class HingeJoint(FreeJoint): """ # TODO: IN WRAPPER COMPUTE THE NORMAL DIRECTION OR ASK USER TO GIVE INPUT, IF NOT THROW ERROR - def __init__(self, k, nu, kt, normal_direction): + def __init__( + self, k: np.floating, nu: np.floating, kt: np.floating, normal_direction: NDArray[np.floating] + ) -> None: """ Parameters @@ -154,19 +168,19 @@ def __init__(self, k, nu, kt, normal_direction): def apply_forces( self, system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: return super().apply_forces(system_one, index_one, system_two, index_two) def apply_torques( self, system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: # current tangent direction of the `index_two` element of system two system_two_tangent = system_two.director_collection[2, :, index_two] @@ -215,7 +229,14 @@ class FixedJoint(FreeJoint): is enforced. """ - def __init__(self, k, nu, kt, nut=0.0, rest_rotation_matrix=None): + def __init__( + self, + k: np.floating, + nu: np.floating, + kt: np.floating, + nut: np.floating = 0.0, + rest_rotation_matrix: Optional[NDArray[np.floating]] = None, + ) -> None: """ Parameters @@ -254,19 +275,19 @@ def __init__(self, k, nu, kt, nut=0.0, rest_rotation_matrix=None): def apply_forces( self, system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: return super().apply_forces(system_one, index_one, system_two, index_two) def apply_torques( self, system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: # collect directors of systems one and two # note that systems can be either rods or rigid bodies system_one_director = system_one.director_collection[..., index_one] @@ -311,10 +332,10 @@ def apply_torques( def get_relative_rotation_two_systems( system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, -): + index_two: int, +) -> NDArray[np.floating]: """ Compute the relative rotation matrix C_12 between system one and system two at the specified elements. @@ -362,7 +383,7 @@ def get_relative_rotation_two_systems( # everything below this comment should be removed beyond v0.4.0 -def _dot_product(a, b): +def _dot_product(a: Any, b: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._dot_product()\n" @@ -370,7 +391,7 @@ def _dot_product(a, b): ) -def _norm(a): +def _norm(a: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._norm()\n" @@ -378,7 +399,7 @@ def _norm(a): ) -def _clip(x, low, high): +def _clip(x: Any, low: Any, high: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._clip()\n" @@ -386,7 +407,7 @@ def _clip(x, low, high): ) -def _out_of_bounds(x, low, high): +def _out_of_bounds(x: Any, low: Any, high: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._out_of_bounds()\n" @@ -394,7 +415,7 @@ def _out_of_bounds(x, low, high): ) -def _find_min_dist(x1, e1, x2, e2): +def _find_min_dist(x1: Any, e1: Any, x2: Any, e2: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._find_min_dist()\n" @@ -403,25 +424,25 @@ def _find_min_dist(x1, e1, x2, e2): def _calculate_contact_forces_rod_rigid_body( - x_collection_rod, - edge_collection_rod, - x_cylinder_center, - x_cylinder_tip, - edge_cylinder, - radii_sum, - length_sum, - internal_forces_rod, - external_forces_rod, - external_forces_cylinder, - external_torques_cylinder, - cylinder_director_collection, - velocity_rod, - velocity_cylinder, - contact_k, - contact_nu, - velocity_damping_coefficient, - friction_coefficient, -): + x_collection_rod: Any, + edge_collection_rod: Any, + x_cylinder_center: Any, + x_cylinder_tip: Any, + edge_cylinder: Any, + radii_sum: Any, + length_sum: Any, + internal_forces_rod: Any, + external_forces_rod: Any, + external_forces_cylinder: Any, + external_torques_cylinder: Any, + cylinder_director_collection: Any, + velocity_rod: Any, + velocity_cylinder: Any, + contact_k: Any, + contact_nu: Any, + velocity_damping_coefficient: Any, + friction_coefficient: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica._contact_functions._calculate_contact_forces_rod_cylinder()\n" @@ -430,23 +451,23 @@ def _calculate_contact_forces_rod_rigid_body( def _calculate_contact_forces_rod_rod( - x_collection_rod_one, - radius_rod_one, - length_rod_one, - tangent_rod_one, - velocity_rod_one, - internal_forces_rod_one, - external_forces_rod_one, - x_collection_rod_two, - radius_rod_two, - length_rod_two, - tangent_rod_two, - velocity_rod_two, - internal_forces_rod_two, - external_forces_rod_two, - contact_k, - contact_nu, -): + x_collection_rod_one: Any, + radius_rod_one: Any, + length_rod_one: Any, + tangent_rod_one: Any, + velocity_rod_one: Any, + internal_forces_rod_one: Any, + external_forces_rod_one: Any, + x_collection_rod_two: Any, + radius_rod_two: Any, + length_rod_two: Any, + tangent_rod_two: Any, + velocity_rod_two: Any, + internal_forces_rod_two: Any, + external_forces_rod_two: Any, + contact_k: Any, + contact_nu: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica._contact_functions._calculate_contact_forces_rod_rod()\n" @@ -455,15 +476,15 @@ def _calculate_contact_forces_rod_rod( def _calculate_contact_forces_self_rod( - x_collection_rod, - radius_rod, - length_rod, - tangent_rod, - velocity_rod, - external_forces_rod, - contact_k, - contact_nu, -): + x_collection_rod: Any, + radius_rod: Any, + length_rod: Any, + tangent_rod: Any, + velocity_rod: Any, + external_forces_rod: Any, + contact_k: Any, + contact_nu: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica._contact_functions._calculate_contact_forces_self_rod()\n" @@ -471,7 +492,7 @@ def _calculate_contact_forces_self_rod( ) -def _aabbs_not_intersecting(aabb_one, aabb_two): +def _aabbs_not_intersecting(aabb_one: Any, aabb_two: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._aabbs_not_intersecting()\n" @@ -480,14 +501,14 @@ def _aabbs_not_intersecting(aabb_one, aabb_two): def _prune_using_aabbs_rod_rigid_body( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - cylinder_position, - cylinder_director, - cylinder_radius, - cylinder_length, -): + rod_one_position_collection: Any, + rod_one_radius_collection: Any, + rod_one_length_collection: Any, + cylinder_position: Any, + cylinder_director: Any, + cylinder_radius: Any, + cylinder_length: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._prune_using_aabbs_rod_cylinder()\n" @@ -496,13 +517,13 @@ def _prune_using_aabbs_rod_rigid_body( def _prune_using_aabbs_rod_rod( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - rod_two_position_collection, - rod_two_radius_collection, - rod_two_length_collection, -): + rod_one_position_collection: Any, + rod_one_radius_collection: Any, + rod_one_length_collection: Any, + rod_two_position_collection: Any, + rod_two_radius_collection: Any, + rod_two_length_collection: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._prune_using_aabbs_rod_rod()\n" @@ -555,7 +576,13 @@ class ExternalContact(FreeJoint): # potentially dangerous as it does not deal with "end" conditions # correctly. - def __init__(self, k, nu, velocity_damping_coefficient=0, friction_coefficient=0): + def __init__( + self, + k: np.floating, + nu: np.floating, + velocity_damping_coefficient: np.floating = 0, + friction_coefficient: np.floating = 0, + ) -> None: """ Parameters @@ -586,11 +613,11 @@ def __init__(self, k, nu, velocity_damping_coefficient=0, friction_coefficient=0 def apply_forces( self, - rod_one: RodType, - index_one, + rod_one: SystemType, + index_one: int, rod_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: # del index_one, index_two from elastica.contact_utils import ( _prune_using_aabbs_rod_cylinder, @@ -693,7 +720,7 @@ class SelfContact(FreeJoint): """ - def __init__(self, k, nu): + def __init__(self, k: np.floating, nu: np.floating) -> None: super().__init__(k, nu) log = logging.getLogger(self.__class__.__name__) log.warning( @@ -705,7 +732,9 @@ def __init__(self, k, nu): "The option to use the SelfContact joint for the rod self contact will be removed in the future (v0.3.3).\n" ) - def apply_forces(self, rod_one: RodType, index_one, rod_two: SystemType, index_two): + def apply_forces( + self, rod_one: SystemType, index_one: int, rod_two: SystemType, index_two: int + ) -> None: # del index_one, index_two from elastica._contact_functions import ( _calculate_contact_forces_self_rod, diff --git a/elastica/restart.py b/elastica/restart.py index 1b5fab26..22375d99 100644 --- a/elastica/restart.py +++ b/elastica/restart.py @@ -5,8 +5,10 @@ from itertools import groupby from .memory_block import MemoryBlockCosseratRod, MemoryBlockRigidBody +from typing import Iterable, Iterator, Any -def all_equal(iterable): + +def all_equal(iterable: Iterable[Any]) -> bool: """ Checks if all elements of list are equal. Parameters @@ -20,11 +22,14 @@ def all_equal(iterable): ---------- https://stackoverflow.com/questions/3844801/check-if-all-elements-in-a-list-are-identical """ - g = groupby(iterable) + g: Iterator[Any] = groupby(iterable) return next(g, True) and not next(g, False) -def save_state(simulator, directory: str = "", time=0.0, verbose: bool = False): +# TODO: simulator should have better typing +def save_state( + simulator: Iterable, directory: str = "", time: np.floating = 0.0, verbose: bool = False +) -> None: """ Save state parameters of each rod. TODO : environment list variable is not uniform at the current stage of development. @@ -53,7 +58,10 @@ def save_state(simulator, directory: str = "", time=0.0, verbose: bool = False): print("Save complete: {}".format(directory)) -def load_state(simulator, directory: str = "", verbose: bool = False): +# TODO: simulator should have better typing +def load_state( + simulator: Iterable, directory: str = "", verbose: bool = False +) -> float: """ Load the rod-state. Compatibale with 'save_state' method. If the save-file does not exist, it returns error. @@ -72,7 +80,7 @@ def load_state(simulator, directory: str = "", verbose: bool = False): time : float Simulation time of systems when they are saved. """ - time_list = [] # Simulation time of rods when they are saved. + time_list: list[float] = [] # Simulation time of rods when they are saved. for idx, rod in enumerate(simulator): if isinstance(rod, MemoryBlockCosseratRod) or isinstance( rod, MemoryBlockRigidBody diff --git a/elastica/rod/cosserat_rod.py b/elastica/rod/cosserat_rod.py index 051b3751..91dd713c 100644 --- a/elastica/rod/cosserat_rod.py +++ b/elastica/rod/cosserat_rod.py @@ -2,6 +2,7 @@ import numpy as np +from numpy.typing import NDArray import functools import numba from elastica.rod import RodBase @@ -20,18 +21,19 @@ _difference, _average, ) -from typing import Optional +from typing import Any, Optional +from typing_extensions import Self position_difference_kernel = _difference position_average = _average @functools.lru_cache(maxsize=1) -def _get_z_vector(): +def _get_z_vector() -> NDArray[np.floating]: return np.array([0.0, 0.0, 1.0]).reshape(3, -1) -def _compute_sigma_kappa_for_blockstructure(memory_block): +def _compute_sigma_kappa_for_blockstructure(memory_block: object) -> None: """ This function is a wrapper to call functions which computes shear stretch, strain and bending twist and strain. @@ -147,39 +149,39 @@ class CosseratRod(RodBase, KnotTheory): def __init__( self, - n_elements, - position, - velocity, - omega, - acceleration, - angular_acceleration, - directors, - radius, - mass_second_moment_of_inertia, - inv_mass_second_moment_of_inertia, - shear_matrix, - bend_matrix, - density, - volume, - mass, - internal_forces, - internal_torques, - external_forces, - external_torques, - lengths, - rest_lengths, - tangents, - dilatation, - dilatation_rate, - voronoi_dilatation, - rest_voronoi_lengths, - sigma, - kappa, - rest_sigma, - rest_kappa, - internal_stress, - internal_couple, - ring_rod_flag, + n_elements: int, + position: NDArray[np.floating], + velocity: NDArray[np.floating], + omega: NDArray[np.floating], + acceleration: NDArray[np.floating], + angular_acceleration: NDArray[np.floating], + directors: NDArray[np.floating], + radius: NDArray[np.floating], + mass_second_moment_of_inertia: NDArray[np.floating], + inv_mass_second_moment_of_inertia: NDArray[np.floating], + shear_matrix: NDArray[np.floating], + bend_matrix: NDArray[np.floating], + density: NDArray[np.floating], + volume: NDArray[np.floating], + mass: NDArray[np.floating], + internal_forces: NDArray[np.floating], + internal_torques: NDArray[np.floating], + external_forces: NDArray[np.floating], + external_torques: NDArray[np.floating], + lengths: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + dilatation: NDArray[np.floating], + dilatation_rate: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + sigma: NDArray[np.floating], + kappa: NDArray[np.floating], + rest_sigma: NDArray[np.floating], + rest_kappa: NDArray[np.floating], + internal_stress: NDArray[np.floating], + internal_couple: NDArray[np.floating], + ring_rod_flag: bool, ): self.n_elems = n_elements self.position_collection = position @@ -242,17 +244,17 @@ def __init__( def straight_rod( cls, n_elements: int, - start: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + start: NDArray[np.floating], + direction: NDArray[np.floating], + normal: NDArray[np.floating], base_length: float, base_radius: float, density: float, *, nu: Optional[float] = None, youngs_modulus: float, - **kwargs, - ): + **kwargs: Any, + ) -> Self: """ Cosserat rod constructor for straight-rod geometry. @@ -390,17 +392,17 @@ def straight_rod( def ring_rod( cls, n_elements: int, - ring_center_position: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + ring_center_position: NDArray[np.floating], + direction: NDArray[np.floating], + normal: NDArray[np.floating], base_length: float, base_radius: float, density: float, *, nu: Optional[float] = None, youngs_modulus: float, - **kwargs, - ): + **kwargs: Any, + ) -> Self: """ Cosserat rod constructor for straight-rod geometry. @@ -536,7 +538,7 @@ def ring_rod( rod.REQUISITE_MODULE.append(Constraints) return rod - def compute_internal_forces_and_torques(self, time): + def compute_internal_forces_and_torques(self, time: float) -> None: """ Compute internal forces and torques. We need to compute internal forces and torques before the acceleration because they are used in interaction. Thus in order to speed up simulation, we will compute internal forces and torques @@ -591,7 +593,7 @@ def compute_internal_forces_and_torques(self, time): ) # Interface to time-stepper mixins (Symplectic, Explicit), which calls this method - def update_accelerations(self, time): + def update_accelerations(self, time: float) -> None: """ Updates the acceleration variables @@ -613,12 +615,12 @@ def update_accelerations(self, time): self.dilatation, ) - def zeroed_out_external_forces_and_torques(self, time): + def zeroed_out_external_forces_and_torques(self, time: np.floating) -> None: _zeroed_out_external_forces_and_torques( self.external_forces, self.external_torques ) - def compute_translational_energy(self): + def compute_translational_energy(self) -> NDArray[np.floating]: """ Compute total translational energy of the rod at the instance. """ @@ -632,7 +634,7 @@ def compute_translational_energy(self): ).sum() ) - def compute_rotational_energy(self): + def compute_rotational_energy(self) -> NDArray[np.floating]: """ Compute total rotational energy of the rod at the instance. """ @@ -642,7 +644,7 @@ def compute_rotational_energy(self): ) return 0.5 * np.einsum("ik,ik->k", self.omega_collection, J_omega_upon_e).sum() - def compute_velocity_center_of_mass(self): + def compute_velocity_center_of_mass(self) -> NDArray[np.floating]: """ Compute velocity center of mass of the rod at the instance. """ @@ -651,7 +653,7 @@ def compute_velocity_center_of_mass(self): return sum_mass_times_velocity / self.mass.sum() - def compute_position_center_of_mass(self): + def compute_position_center_of_mass(self) -> NDArray[np.floating]: """ Compute position center of mass of the rod at the instance. """ @@ -660,7 +662,7 @@ def compute_position_center_of_mass(self): return sum_mass_times_position / self.mass.sum() - def compute_bending_energy(self): + def compute_bending_energy(self) -> NDArray[np.floating]: """ Compute total bending energy of the rod at the instance. """ @@ -676,7 +678,7 @@ def compute_bending_energy(self): ).sum() ) - def compute_shear_energy(self): + def compute_shear_energy(self) -> NDArray[np.floating]: """ Compute total shear energy of the rod at the instance. """ @@ -695,8 +697,12 @@ def compute_shear_energy(self): @numba.njit(cache=True) def _compute_geometry_from_state( - position_collection, volume, lengths, tangents, radius -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], +) -> None: """ Update given . """ @@ -719,16 +725,16 @@ def _compute_geometry_from_state( @numba.njit(cache=True) def _compute_all_dilatations( - position_collection, - volume, - lengths, - tangents, - radius, - dilatation, - rest_lengths, - rest_voronoi_lengths, - voronoi_dilatation, -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], + dilatation: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], +) -> None: """ Update """ @@ -749,8 +755,12 @@ def _compute_all_dilatations( @numba.njit(cache=True) def _compute_dilatation_rate( - position_collection, velocity_collection, lengths, rest_lengths, dilatation_rate -): + position_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + lengths: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + dilatation_rate: NDArray[np.floating], +) -> None: """ Update dilatation_rate given position, velocity, length, and rest_length """ @@ -776,18 +786,18 @@ def _compute_dilatation_rate( @numba.njit(cache=True) def _compute_shear_stretch_strains( - position_collection, - volume, - lengths, - tangents, - radius, - rest_lengths, - rest_voronoi_lengths, - dilatation, - voronoi_dilatation, - director_collection, - sigma, -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + dilatation: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + director_collection: NDArray[np.floating], + sigma: NDArray[np.floating], +) -> None: """ Update given . """ @@ -811,21 +821,21 @@ def _compute_shear_stretch_strains( @numba.njit(cache=True) def _compute_internal_shear_stretch_stresses_from_model( - position_collection, - volume, - lengths, - tangents, - radius, - rest_lengths, - rest_voronoi_lengths, - dilatation, - voronoi_dilatation, - director_collection, - sigma, - rest_sigma, - shear_matrix, - internal_stress, -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + dilatation: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + director_collection: NDArray[np.floating], + sigma: NDArray[np.floating], + rest_sigma: NDArray[np.floating], + shear_matrix: NDArray[np.floating], + internal_stress: NDArray[np.floating], +) -> None: """ Update given . @@ -850,7 +860,11 @@ def _compute_internal_shear_stretch_stresses_from_model( @numba.njit(cache=True) -def _compute_bending_twist_strains(director_collection, rest_voronoi_lengths, kappa): +def _compute_bending_twist_strains( + director_collection: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + kappa: NDArray[np.floating], +) -> None: """ Update given . """ @@ -864,13 +878,13 @@ def _compute_bending_twist_strains(director_collection, rest_voronoi_lengths, ka @numba.njit(cache=True) def _compute_internal_bending_twist_stresses_from_model( - director_collection, - rest_voronoi_lengths, - internal_couple, - bend_matrix, - kappa, - rest_kappa, -): + director_collection: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + internal_couple: NDArray[np.floating], + bend_matrix: NDArray[np.floating], + kappa: NDArray[np.floating], + rest_kappa: NDArray[np.floating], +) -> None: """ Upate given . @@ -893,23 +907,23 @@ def _compute_internal_bending_twist_stresses_from_model( @numba.njit(cache=True) def _compute_internal_forces( - position_collection, - volume, - lengths, - tangents, - radius, - rest_lengths, - rest_voronoi_lengths, - dilatation, - voronoi_dilatation, - director_collection, - sigma, - rest_sigma, - shear_matrix, - internal_stress, - internal_forces, - ghost_elems_idx, -): + position_collection: NDArray[np.floating], + volume: NDArray[np.floating], + lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + radius: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + dilatation: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + director_collection: NDArray[np.floating], + sigma: NDArray[np.floating], + rest_sigma: NDArray[np.floating], + shear_matrix: NDArray[np.floating], + internal_stress: NDArray[np.floating], + internal_forces: NDArray[np.floating], + ghost_elems_idx: NDArray[np.floating], +) -> None: """ Update given . """ @@ -954,26 +968,26 @@ def _compute_internal_forces( @numba.njit(cache=True) def _compute_internal_torques( - position_collection, - velocity_collection, - tangents, - lengths, - rest_lengths, - director_collection, - rest_voronoi_lengths, - bend_matrix, - rest_kappa, - kappa, - voronoi_dilatation, - mass_second_moment_of_inertia, - omega_collection, - internal_stress, - internal_couple, - dilatation, - dilatation_rate, - internal_torques, - ghost_voronoi_idx, -): + position_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + tangents: NDArray[np.floating], + lengths: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + director_collection: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + bend_matrix: NDArray[np.floating], + rest_kappa: NDArray[np.floating], + kappa: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + mass_second_moment_of_inertia: NDArray[np.floating], + omega_collection: NDArray[np.floating], + internal_stress: NDArray[np.floating], + internal_couple: NDArray[np.floating], + dilatation: NDArray[np.floating], + dilatation_rate: NDArray[np.floating], + internal_torques: NDArray[np.floating], + ghost_voronoi_idx: NDArray[np.integer], +) -> None: """ Update . """ @@ -1043,16 +1057,16 @@ def _compute_internal_torques( @numba.njit(cache=True) def _update_accelerations( - acceleration_collection, - internal_forces, - external_forces, - mass, - alpha_collection, - inv_mass_second_moment_of_inertia, - internal_torques, - external_torques, - dilatation, -): + acceleration_collection: NDArray[np.floating], + internal_forces: NDArray[np.floating], + external_forces: NDArray[np.floating], + mass: NDArray[np.floating], + alpha_collection: NDArray[np.floating], + inv_mass_second_moment_of_inertia: NDArray[np.floating], + internal_torques: NDArray[np.floating], + external_torques: NDArray[np.floating], + dilatation: NDArray[np.floating], +) -> None: """ Update given . """ @@ -1077,7 +1091,9 @@ def _update_accelerations( @numba.njit(cache=True) -def _zeroed_out_external_forces_and_torques(external_forces, external_torques): +def _zeroed_out_external_forces_and_torques( + external_forces: NDArray[np.floating], external_torques: NDArray[np.floating] +) -> None: """ This function is to zeroed out external forces and torques. diff --git a/elastica/rod/data_structures.py b/elastica/rod/data_structures.py index 3c8b4032..c65a9c4b 100644 --- a/elastica/rod/data_structures.py +++ b/elastica/rod/data_structures.py @@ -1,6 +1,9 @@ __doc__ = "Data structure wrapper for rod components" +from typing import Any, Optional +from typing_extensions import Self import numpy as np +from numpy.typing import NDArray from numba import njit from elastica._rotations import _get_rotation_matrix, _rotate from elastica._linalg import _batch_matmul @@ -9,7 +12,7 @@ # FIXME : Explicit Stepper doesn't work as States lose the # views they initially had when working with a timestepper. # class _RodExplicitStepperMixin: -# def __init__(self): +# def __init__(self) -> None: # ( # self.state, # self.__deriv_state, @@ -43,7 +46,7 @@ class _RodSymplecticStepperMixin: - def __init__(self): + def __init__(self) -> None: self.kinematic_states = _KinematicState( self.position_collection, self.director_collection ) @@ -60,18 +63,40 @@ def __init__(self): # is another function self.kinematic_rates = self.dynamic_states.kinematic_rates - def update_internal_forces_and_torques(self, time, *args, **kwargs): + def update_internal_forces_and_torques( + self, time: np.floating, *args: Any, **kwargs: Any + ) -> None: self.compute_internal_forces_and_torques(time) - def dynamic_rates(self, time, prefac, *args, **kwargs): + def dynamic_rates( + self, time: np.floating, prefac: np.floating, *args: Any, **kwargs: Any + ) -> NDArray[np.floating]: self.update_accelerations(time) return self.dynamic_states.dynamic_rates(time, prefac, *args, **kwargs) - def reset_external_forces_and_torques(self, time, *args, **kwargs): + def reset_external_forces_and_torques( + self, time: np.floating, *args: Any, **kwargs: Any + ) -> None: self.zeroed_out_external_forces_and_torques(time) -def _bootstrap_from_data(stepper_type: str, n_elems: int, vector_states, matrix_states): +def _bootstrap_from_data( + stepper_type: str, + n_elems: int, + vector_states: NDArray[np.floating], + matrix_states: NDArray[np.floating], +) -> Optional[ + tuple[ + "_State", + "_DerivativeState", + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + ] +]: """Returns states wrapping numpy arrays based on the time-stepping algorithm Convenience method that takes in rod internal (raw np.ndarray) data, create views @@ -119,7 +144,7 @@ def _bootstrap_from_data(stepper_type: str, n_elems: int, vector_states, matrix_ # ) raise NotImplementedError else: - return + return None n_velocity_end = n_nodes + n_nodes velocity = np.ndarray.view(vector_states[..., n_nodes:n_velocity_end]) @@ -154,10 +179,10 @@ class _State: def __init__( self, n_elems: int, - position_collection_view, - director_collection_view, - kinematic_rate_collection_view, - ): + position_collection_view: NDArray[np.floating], + director_collection_view: NDArray[np.floating], + kinematic_rate_collection_view: NDArray[np.floating], + ) -> None: """ Parameters ---------- @@ -173,7 +198,7 @@ def __init__( self.director_collection = director_collection_view self.kinematic_rate_collection = kinematic_rate_collection_view - def __iadd__(self, scaled_deriv_array): + def __iadd__(self, scaled_deriv_array: NDArray[np.floating]) -> Self: """overloaded += operator The add for directors is customized to reflect Rodrigues' rotation @@ -242,7 +267,7 @@ def __iadd__(self, scaled_deriv_array): ] return self - def __add__(self, scaled_derivative_state): + def __add__(self, scaled_derivative_state: NDArray[np.floating]) -> "_State": """overloaded + operator, useful in state.k1 = state + dt * deriv_state The add for directors is customized to reflect Rodrigues' rotation @@ -296,7 +321,9 @@ class _DerivativeState: /multiplication used. """ - def __init__(self, _unused_n_elems: int, rate_collection_view): + def __init__( + self, _unused_n_elems: int, rate_collection_view: NDArray[np.floating] + ) -> None: """ Parameters ---------- @@ -307,7 +334,7 @@ def __init__(self, _unused_n_elems: int, rate_collection_view): super(_DerivativeState, self).__init__() self.rate_collection = rate_collection_view - def __rmul__(self, scalar): + def __rmul__(self, scalar: np.floating) -> NDArray[np.floating]: """overloaded scalar * self, Parameters @@ -355,7 +382,7 @@ def __rmul__(self, scalar): """ return scalar * self.rate_collection - def __mul__(self, scalar): + def __mul__(self, scalar: np.floating) -> NDArray[np.floating]: """overloaded self * scalar TODO Check if this pattern (forwarding to __mul__) has @@ -388,7 +415,11 @@ class _KinematicState: only these methods are provided. """ - def __init__(self, position_collection_view, director_collection_view): + def __init__( + self, + position_collection_view: NDArray[np.floating], + director_collection_view: NDArray[np.floating], + ) -> None: """ Parameters ---------- @@ -403,13 +434,13 @@ def __init__(self, position_collection_view, director_collection_view): @njit(cache=True) def overload_operator_kinematic_numba( - n_nodes, - prefac, - position_collection, - director_collection, - velocity_collection, - omega_collection, -): + n_nodes: int, + prefac: np.floating, + position_collection: NDArray[np.floating], + director_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + omega_collection: NDArray[np.floating], +) -> None: """overloaded += operator The add for directors is customized to reflect Rodrigues' rotation @@ -449,11 +480,11 @@ class _DynamicState: def __init__( self, - v_w_collection, - dvdt_dwdt_collection, - velocity_collection, - omega_collection, - ): + v_w_collection: NDArray[np.floating], + dvdt_dwdt_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + omega_collection: NDArray[np.floating], + ) -> None: """ Parameters ---------- @@ -470,7 +501,9 @@ def __init__( self.velocity_collection = velocity_collection self.omega_collection = omega_collection - def kinematic_rates(self, time, prefac): + def kinematic_rates( + self, time: np.floating, prefac: np.floating + ) -> tuple[NDArray[np.floating], NDArray[np.floating]]: """Yields kinematic rates to interact with _KinematicState Returns @@ -486,7 +519,9 @@ def kinematic_rates(self, time, prefac): # Comes from kin_state -> (x,Q) += dt * (v,w) <- First part of dyn_state return self.velocity_collection, self.omega_collection - def dynamic_rates(self, time, prefac): + def dynamic_rates( + self, time: np.floating, prefac: np.floating + ) -> NDArray[np.floating]: """Yields dynamic rates to add to with _DynamicState Returns ------- @@ -501,7 +536,10 @@ def dynamic_rates(self, time, prefac): @njit(cache=True) -def overload_operator_dynamic_numba(rate_collection, scaled_second_deriv_array): +def overload_operator_dynamic_numba( + rate_collection: NDArray[np.floating], + scaled_second_deriv_array: NDArray[np.floating], +) -> None: """overloaded += operator, updating dynamic_rates Parameters ---------- diff --git a/elastica/rod/factory_function.py b/elastica/rod/factory_function.py index 6eca6b5a..22cc33bc 100644 --- a/elastica/rod/factory_function.py +++ b/elastica/rod/factory_function.py @@ -1,30 +1,64 @@ __doc__ = """ Factory function to allocate variables for Cosserat Rod""" -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import logging import numpy as np from numpy.testing import assert_allclose +from numpy.typing import NDArray from elastica.utils import MaxDimension, Tolerance from elastica._linalg import _batch_cross, _batch_norm, _batch_dot def allocate( - n_elements, - direction, - normal, - base_length, - base_radius, - density, - youngs_modulus: float, + n_elements: int, + direction: NDArray[np.floating], + normal: NDArray[np.floating], + base_length: np.floating, + base_radius: np.floating, + density: np.floating, + youngs_modulus: np.floating, *, rod_origin_position: np.ndarray, ring_rod_flag: bool, - shear_modulus: Optional[float] = None, + shear_modulus: Optional[np.floating] = None, position: Optional[np.ndarray] = None, directors: Optional[np.ndarray] = None, rest_sigma: Optional[np.ndarray] = None, rest_kappa: Optional[np.ndarray] = None, - **kwargs, -): + **kwargs: Any, +) -> tuple[ + int, + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], + NDArray[np.floating], +]: log = logging.getLogger() if "poisson_ratio" in kwargs: @@ -335,14 +369,16 @@ def allocate( """ -def _assert_dim(vector, max_dim: int, name: str): +def _assert_dim(vector: np.ndarray, max_dim: int, name: str) -> None: assert vector.ndim < max_dim, ( f"Input {name} dimension is not correct {vector.shape}" + f" It should be maximum {max_dim}D vector or single floating number." ) -def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): +def _assert_shape( + array: np.ndarray, expected_shape: Tuple[int, ...], name: str +) -> None: assert array.shape == expected_shape, ( f"Given {name} shape is not correct, it should be " + str(expected_shape) @@ -351,7 +387,9 @@ def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): ) -def _position_validity_checker(position, start, n_elements): +def _position_validity_checker( + position: NDArray[np.floating], start: NDArray[np.floating], n_elements: int +) -> None: """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements + 1), "position") @@ -367,7 +405,9 @@ def _position_validity_checker(position, start, n_elements): ) -def _directors_validity_checker(directors, tangents, n_elements): +def _directors_validity_checker( + directors: NDArray[np.floating], tangents: NDArray[np.floating], n_elements: int +) -> None: """Checker on user-defined directors validity""" _assert_shape( directors, (MaxDimension.value(), MaxDimension.value(), n_elements), "directors" @@ -413,7 +453,11 @@ def _directors_validity_checker(directors, tangents, n_elements): ) -def _position_validity_checker_ring_rod(position, ring_center_position, n_elements): +def _position_validity_checker_ring_rod( + position: NDArray[np.floating], + ring_center_position: NDArray[np.floating], + n_elements: int, +) -> None: """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements), "position") diff --git a/elastica/rod/knot_theory.py b/elastica/rod/knot_theory.py index 23d2b009..006b7527 100644 --- a/elastica/rod/knot_theory.py +++ b/elastica/rod/knot_theory.py @@ -14,6 +14,7 @@ from numba import njit import numpy as np +from numpy.typing import NDArray from elastica.rod.rod_base import RodBase from elastica._linalg import _batch_norm, _batch_dot, _batch_cross @@ -38,7 +39,7 @@ def radius(self) -> np.ndarray: ... def base_length(self) -> np.ndarray: ... -class KnotTheory: +class KnotTheory(KnotTheoryCompatibleProtocol): """ This mixin should be used in RodBase-derived class that satisfies KnotCompatibleProtocol. The theory behind this module is based on the method from Klenin & Langowski 2000 paper. @@ -46,7 +47,7 @@ class KnotTheory: KnotTheory can be mixed with any rod-class based on RodBase:: class MyRod(RodBase, KnotTheory): - def __init__(self): + def __init__(self) -> None: super().__init__() rod = MyRod(...) @@ -78,7 +79,7 @@ def __init__(self): MIXIN_PROTOCOL = Union[RodBase, KnotTheoryCompatibleProtocol] - def compute_twist(self: MIXIN_PROTOCOL): + def compute_twist(self: MIXIN_PROTOCOL) -> NDArray[np.floating]: """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. """ @@ -91,7 +92,7 @@ def compute_twist(self: MIXIN_PROTOCOL): def compute_writhe( self: MIXIN_PROTOCOL, type_of_additional_segment: str = "next_tangent", - alpha: float = 1.0, + alpha: np.floating = 1.0, ): """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. @@ -114,8 +115,8 @@ def compute_writhe( def compute_link( self: MIXIN_PROTOCOL, type_of_additional_segment: str = "next_tangent", - alpha: float = 1.0, - ): + alpha: np.floating = 1.0, + ) -> NDArray[np.floating]: """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. @@ -138,7 +139,9 @@ def compute_link( )[0] -def compute_twist(center_line, normal_collection): +def compute_twist( + center_line: NDArray[np.floating], normal_collection: NDArray[np.floating] +) -> tuple[NDArray[np.floating], NDArray[np.floating]]: """ Compute the twist of a rod, using center_line and normal collection. @@ -189,7 +192,9 @@ def compute_twist(center_line, normal_collection): @njit(cache=True) -def _compute_twist(center_line, normal_collection): +def _compute_twist( + center_line: NDArray[np.floating], normal_collection: NDArray[np.floating] +) -> tuple[NDArray[np.floating], NDArray[np.floating]]: """ Parameters ---------- @@ -264,7 +269,11 @@ def _compute_twist(center_line, normal_collection): return total_twist, local_twist -def compute_writhe(center_line, segment_length, type_of_additional_segment): +def compute_writhe( + center_line: NDArray[np.floating], + segment_length: np.floating, + type_of_additional_segment: str, +) -> NDArray[np.floating]: """ This function computes the total writhe history of a rod. @@ -314,7 +323,7 @@ def compute_writhe(center_line, segment_length, type_of_additional_segment): @njit(cache=True) -def _compute_writhe(center_line): +def _compute_writhe(center_line: NDArray[np.floating]) -> NDArray[np.floating]: """ Parameters ---------- @@ -386,12 +395,12 @@ def _compute_writhe(center_line): def compute_link( - center_line: np.ndarray, - normal_collection: np.ndarray, - radius: np.ndarray, - segment_length: float, + center_line: NDArray[np.floating], + normal_collection: NDArray[np.floating], + radius: NDArray[np.floating], + segment_length: np.floating, type_of_additional_segment: str, -): +) -> NDArray[np.floating]: """ This function computes the total link history of a rod. @@ -470,7 +479,11 @@ def compute_link( @njit(cache=True) -def _compute_auxiliary_line(center_line, normal_collection, radius): +def _compute_auxiliary_line( + center_line: NDArray[np.floating], + normal_collection: NDArray[np.floating], + radius: NDArray[np.floating], +) -> NDArray[np.floating]: """ This function computes the auxiliary line using rod center line and normal collection. @@ -525,7 +538,9 @@ def _compute_auxiliary_line(center_line, normal_collection, radius): @njit(cache=True) -def _compute_link(center_line, auxiliary_line): +def _compute_link( + center_line: NDArray[np.floating], auxiliary_line: NDArray[np.floating] +) -> NDArray[np.floating]: """ Parameters @@ -604,8 +619,11 @@ def _compute_link(center_line, auxiliary_line): @njit(cache=True) def _compute_auxiliary_line_added_segments( - beginning_direction, end_direction, auxiliary_line, segment_length -): + beginning_direction: NDArray[np.floating], + end_direction: NDArray[np.floating], + auxiliary_line: NDArray[np.floating], + segment_length: np.floating, +) -> NDArray[np.floating]: """ This code is for computing position of added segments to the auxiliary line. @@ -647,8 +665,10 @@ def _compute_auxiliary_line_added_segments( @njit(cache=True) def _compute_additional_segment( - center_line, segment_length, type_of_additional_segment -): + center_line: NDArray[np.floating], + segment_length: np.floating, + type_of_additional_segment: str, +) -> tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]: """ This function adds two points at the end of center line. Distance from the center line is given by segment_length. Direction from center line to the new point locations can be computed using 3 methods, which can be selected by diff --git a/elastica/rod/rod_base.py b/elastica/rod/rod_base.py index 314efe6d..0edc919e 100644 --- a/elastica/rod/rod_base.py +++ b/elastica/rod/rod_base.py @@ -1,5 +1,8 @@ __doc__ = """Base class for rods""" +from typing import Any +import numpy as np +from numpy.typing import NDArray class RodBase: """ @@ -11,10 +14,15 @@ class RodBase: """ - REQUISITE_MODULES = [] + REQUISITE_MODULES: list[Any] = [] - def __init__(self): + def __init__(self) -> None: """ RodBase does not take any arguments. """ - pass + self.position_collection: NDArray[np.floating] + self.omega_collection: NDArray[np.floating] + self.acceleration_collection: NDArray[np.floating] + self.alpha_collection: NDArray[np.floating] + self.external_forces: NDArray[np.floating] + self.external_torques: NDArray[np.floating] diff --git a/elastica/transformations.py b/elastica/transformations.py index b002e956..c8e97721 100644 --- a/elastica/transformations.py +++ b/elastica/transformations.py @@ -10,11 +10,14 @@ from .utils import MaxDimension, isqrt +from numpy.typing import NDArray # TODO Complete, but nicer interface, evolve it eventually -def format_vector_shape(vector_collection): +def format_vector_shape( + vector_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ Function for formatting vector shapes into correct format @@ -59,7 +62,9 @@ def format_vector_shape(vector_collection): return vector_collection -def format_matrix_shape(matrix_collection): +def format_matrix_shape( + matrix_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ Formats input matrix into correct format @@ -77,7 +82,7 @@ def format_matrix_shape(matrix_collection): # check first two dimensions are same and matrix is square # other possibility is one dimension is dim**2 and other is blocksize, # we need to convert the matrix in that case. - def assert_proper_square(num1): + def assert_proper_square(num1: int) -> int: sqrt_num = isqrt(num1) assert sqrt_num**2 == num1, "Matrix dimension passed is not a perfect square" return sqrt_num @@ -136,12 +141,14 @@ def assert_proper_square(num1): return matrix_collection -def skew_symmetrize(vector): +def skew_symmetrize(vector: NDArray[np.floating]) -> NDArray[np.floating]: vector = format_vector_shape(vector) return _skew_symmetrize(vector) -def inv_skew_symmetrize(matrix_collection): +def inv_skew_symmetrize( + matrix_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ Safe wrapper around inv_skew_symmetrize that does checking and formatting on type of matrix_collection using format_matrix_shape @@ -167,7 +174,9 @@ def inv_skew_symmetrize(matrix_collection): raise ValueError("matrix_collection passed is not skew-symmetric") -def rotate(matrix, scale, axis): +def rotate( + matrix: NDArray[np.floating], scale: np.floating, axis: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function takes single or multiple frames as matrix. Then rotates these frames around a single axis for all frames, or can rotate each frame around its own diff --git a/elastica/typing.py b/elastica/typing.py index 6599f1af..fd1337fe 100644 --- a/elastica/typing.py +++ b/elastica/typing.py @@ -70,7 +70,7 @@ # [SymplecticStepperProtocol, np.floating], np.floating # ] OperatorType: TypeAlias = Callable[ - Any, Any + ..., Any ] # TODO: Maybe can be more specific. Up for discussion. SteppersOperatorsType: TypeAlias = tuple[tuple[OperatorType, ...], ...] # tuple[Union[PrefactorOperatorType, StepOperatorType, NoOpType, np.floating], ...], ... diff --git a/elastica/utils.py b/elastica/utils.py index bbf9baa2..81517d54 100644 --- a/elastica/utils.py +++ b/elastica/utils.py @@ -1,12 +1,15 @@ """ Handy utilities """ +from typing import Generator, Iterable, Any, Literal, TypeVar import functools import numpy as np from numpy import finfo, float64 from itertools import islice from scipy.interpolate import BSpline +from numpy.typing import NDArray + # Slower than the python3.8 isqrt implementation for small ints # python isqrt : ~130 ns @@ -47,6 +50,8 @@ def isqrt(num: int) -> int: elif num == 0: return 0 + raise ValueError("num must be a positive number") + class MaxDimension: """ @@ -54,7 +59,7 @@ class MaxDimension: """ @staticmethod - def value(): + def value() -> Literal[3]: """ Returns spatial dimension @@ -67,7 +72,7 @@ def value(): class Tolerance: @staticmethod - def atol(): + def atol() -> np.floating: """ Static absolute tolerance method @@ -78,7 +83,7 @@ def atol(): return finfo(float64).eps * 1e4 @staticmethod - def rtol(): + def rtol() -> np.floating: """ Static relative tolerance method @@ -89,7 +94,7 @@ def rtol(): return finfo(float64).eps * 1e11 -def perm_parity(lst): +def perm_parity(lst: list[int]) -> int: """ Given a permutation of the digits 0..N in order as a list, returns its parity (or sign): +1 for even parity; -1 for odd. @@ -115,7 +120,10 @@ def perm_parity(lst): return parity -def grouper(iterable, n): +_T = TypeVar("_T") + + +def grouper(iterable: Iterable[_T], n: int) -> Generator[tuple[_T, ...], None, None]: """Collect data into fixed-length chunks or blocks" Parameters @@ -144,7 +152,7 @@ def grouper(iterable, n): yield group -def extend_instance(obj, cls): +def extend_instance(obj: Any, cls: Any) -> None: """ Apply mixins to a class instance after creation @@ -170,7 +178,9 @@ def extend_instance(obj, cls): obj.__class__ = type(base_cls_name, (cls, base_cls), {}) -def _bspline(t_coeff, l_centerline=1.0): +def _bspline( + t_coeff: NDArray, l_centerline: np.floating = 1.0 +) -> tuple[BSpline, NDArray, NDArray]: """Generates a bspline object that plots the spline interpolant for any vector x. Optionally takes in a centerline length, set to 1.0 by default and keep_pts for keeping record of control points @@ -198,7 +208,9 @@ def _bspline(t_coeff, l_centerline=1.0): return __bspline_impl__(control_pts, t_coeff, degree) -def __bspline_impl__(x_pts, t_c, degree): +def __bspline_impl__( + x_pts: NDArray, t_c: NDArray, degree: int +) -> tuple[BSpline, NDArray, NDArray]: """""" # Update the knots