From 8ff2c891882b0df79b130c6bb88953ab6ffc56a3 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 12:05:31 +0200 Subject: [PATCH 01/40] Exchange List with Sequence in args --- src/ptwt/_stationary_transform.py | 5 +++-- src/ptwt/_util.py | 15 ++++++++------- src/ptwt/conv_transform.py | 14 +++++++------- src/ptwt/conv_transform_2.py | 7 ++++--- src/ptwt/conv_transform_3.py | 6 +++--- src/ptwt/matmul_transform.py | 5 +++-- src/ptwt/matmul_transform_2.py | 3 ++- src/ptwt/matmul_transform_3.py | 5 +++-- src/ptwt/packets.py | 3 ++- src/ptwt/separable_conv_transform.py | 13 +++++++------ 10 files changed, 42 insertions(+), 34 deletions(-) diff --git a/src/ptwt/_stationary_transform.py b/src/ptwt/_stationary_transform.py index c1fe5484..d538a01b 100644 --- a/src/ptwt/_stationary_transform.py +++ b/src/ptwt/_stationary_transform.py @@ -1,5 +1,6 @@ """This module implements stationary wavelet transforms.""" +from collections.abc import Sequence from typing import List, Optional, Union import pywt @@ -107,14 +108,14 @@ def _conv_transpose_dedilate( def _iswt( - coeffs: List[torch.Tensor], + coeffs: Sequence[torch.Tensor], wavelet: Union[pywt.Wavelet, str], axis: Optional[int] = -1, ) -> torch.Tensor: """Inverts a 1d stationary wavelet transform. Args: - coeffs (List[torch.Tensor]): The coefficients as computed by the swt function. + coeffs (Sequence[torch.Tensor]): The coefficients as computed by the swt function. wavelet (Union[pywt.Wavelet, str]): The wavelet used for the forward transform. axis (int, optional): The axis the forward trasform was computed over. Defaults to -1. diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index b0875dbe..73214f62 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -1,7 +1,8 @@ """Utility methods to compute wavelet decompositions from a dataset.""" +from collections.abc import Sequence import typing -from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Protocol, Tuple, Union import numpy as np import pywt @@ -92,7 +93,7 @@ def _pad_symmetric_1d(signal: torch.Tensor, pad_list: Tuple[int, int]) -> torch. 
def _pad_symmetric( - signal: torch.Tensor, pad_lists: List[Tuple[int, int]] + signal: torch.Tensor, pad_lists: Sequence[Tuple[int, int]] ) -> torch.Tensor: if len(signal.shape) < len(pad_lists): raise ValueError("not enough dimensions to pad.") @@ -137,13 +138,13 @@ def _check_if_tensor(array: Any) -> torch.Tensor: return array -def _check_axes_argument(axes: List[int]) -> None: +def _check_axes_argument(axes: Sequence[int]) -> None: if len(set(axes)) != len(axes): raise ValueError("Cant transform the same axis twice.") def _get_transpose_order( - axes: List[int], data_shape: List[int] + axes: Sequence[int], data_shape: Sequence[int] ) -> Tuple[List[int], List[int]]: axes = list(map(lambda a: a + len(data_shape) if a < 0 else a, axes)) all_axes = list(range(len(data_shape))) @@ -151,13 +152,13 @@ def _get_transpose_order( return remove_transformed, axes -def _swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor: +def _swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: _check_axes_argument(axes) front, back = _get_transpose_order(axes, list(data.shape)) return torch.permute(data, front + back) -def _undo_swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor: +def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: _check_axes_argument(axes) front, back = _get_transpose_order(axes, list(data.shape)) restore_sorted = torch.argsort(torch.tensor(front + back)).tolist() @@ -165,7 +166,7 @@ def _undo_swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor: def _map_result( - data: List[Union[torch.Tensor, Any]], # following jax tree_map typing can be Any + data: Sequence[Union[torch.Tensor, Any]], # following jax tree_map typing can be Any function: Callable[[Any], torch.Tensor], ) -> List[Union[torch.Tensor, Any]]: # Apply the given function to the input list of tensor and tuples. diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index a90e1ed6..24500532 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -170,15 +170,15 @@ def _fwt_pad( def _flatten_2d_coeff_lst( - coeff_lst_2d: List[ + coeff_lst_2d: Sequence[ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ], flatten_tensors: bool = True, ) -> List[torch.Tensor]: - """Flattens a list of tensor tuples into a single list. + """Flattens a sequence of tensor tuples into a single list. Args: - coeff_lst_2d (list): A pywt-style coefficient list of torch tensors. + coeff_lst_2d (Sequence): A pywt-style coefficient sequence of torch tensors. flatten_tensors (bool): If true, 2d tensors are flattened. Defaults to True. Returns: @@ -243,14 +243,14 @@ def _preprocess_tensor_dec1d( def _postprocess_result_list_dec1d( - result_lst: List[torch.Tensor], ds: List[int] + result_lst: Sequence[torch.Tensor], ds: List[int] ) -> List[torch.Tensor]: # Unfold axes for the wavelets return [_unfold_axes(fres, ds, 1) for fres in result_lst] def _preprocess_result_list_rec1d( - result_lst: List[torch.Tensor], + result_lst: Sequence[torch.Tensor], ) -> Tuple[List[torch.Tensor], List[int]]: # Fold axes for the wavelets ds = list(result_lst[0].shape) @@ -360,12 +360,12 @@ def wavedec( def waverec( - coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str], axis: int = -1 + coeffs: Sequence[torch.Tensor], wavelet: Union[Wavelet, str], axis: int = -1 ) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. Args: - coeffs (list): The wavelet coefficient list produced by wavedec. 
+ coeffs (Sequence): The wavelet coefficient sequence produced by wavedec. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. axis (int): Transform this axis instead of the last one. Defaults to -1. diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index ca987cb7..496d38bc 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -4,6 +4,7 @@ torch.nn.functional.conv_transpose2d under the hood. """ +from collections.abc import Sequence from functools import partial from typing import List, Optional, Tuple, Union, cast @@ -96,7 +97,7 @@ def _fwt_pad2( def _waverec2d_fold_channels_2d_list( - coeffs: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], ) -> Tuple[ List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], List[int], @@ -239,7 +240,7 @@ def wavedec2( def waverec2( - coeffs: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], wavelet: Union[Wavelet, str], axes: Tuple[int, int] = (-2, -1), ) -> torch.Tensor: @@ -249,7 +250,7 @@ def waverec2( or forward transform by running transposed convolutions. Args: - coeffs (list): The wavelet coefficient list produced by wavedec2. + coeffs (sequence): The wavelet coefficient sequence produced by wavedec2. The coefficients must be in pywt order. That is:: [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] . diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 38341cb9..12687d23 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -210,7 +210,7 @@ def wavedec3( def _waverec3d_fold_channels_3d_list( - coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], ) -> Tuple[ List[Union[torch.Tensor, Dict[str, torch.Tensor]]], List[int], @@ -230,14 +230,14 @@ def _waverec3d_fold_channels_3d_list( def waverec3( - coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: Union[Wavelet, str], axes: Tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. Args: - coeffs (list): The wavelet coefficient list produced by wavedec3. + coeffs (sequence): The wavelet coefficient sequence produced by wavedec3. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. axes (Tuple[int, int, int]): Transform these axes instead of the diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index aa2af955..9204d674 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -8,6 +8,7 @@ """ import sys +from collections import Sequence from typing import List, Optional, Union import numpy as np @@ -595,11 +596,11 @@ def _construct_synthesis_matrices( self.ifwt_matrix_list.append(sn) curr_length = curr_length // 2 - def __call__(self, coefficients: List[torch.Tensor]) -> torch.Tensor: + def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: """Run the synthesis or inverse matrix fwt. Args: - coefficients (List[torch.Tensor]): The coefficients produced by the forward + coefficients (Sequence[torch.Tensor]): The coefficients produced by the forward transform. 
Returns: diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 215e738a..2c86e13d 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -4,6 +4,7 @@ """ import sys +from collections.abc import Sequence from functools import partial from typing import List, Optional, Tuple, Union, cast @@ -726,7 +727,7 @@ def _construct_synthesis_matrices( def __call__( self, - coefficients: List[ + coefficients: Sequence[ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ], ) -> torch.Tensor: diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 4c2ab65f..1bf10058 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -1,6 +1,7 @@ """Implement 3D separable boundary transforms.""" import sys +from collections.abc import Sequence from functools import partial from typing import Dict, List, NamedTuple, Optional, Tuple, Union @@ -386,12 +387,12 @@ def _cat_coeff_recursive(self, input_dict: Dict[str, torch.Tensor]) -> torch.Ten return self._cat_coeff_recursive(done_dict) def __call__( - self, coefficients: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] + self, coefficients: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]] ) -> torch.Tensor: """Reconstruct a batched 3d-signal from its coefficients. Args: - coefficients (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coefficients (Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]]): The output from MatrixWavedec3. Returns: diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 95e653b7..a522741b 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -1,6 +1,7 @@ """Compute analysis wavelet packet representations.""" import collections +from collections.abc import Sequence from functools import partial from itertools import product from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union, cast @@ -465,7 +466,7 @@ def _transform_tuple_to_fsdict_func( torch.Tensor, ]: def _fsdict_func( - coeffs: List[ + coeffs: Sequence[ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ] ) -> torch.Tensor: diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 91106f6c..9be3589d 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -7,6 +7,7 @@ using torch.nn.functional.conv1d and it's transpose. """ +from collections.abc import Sequence from functools import partial from typing import Dict, List, Optional, Tuple, Union @@ -143,13 +144,13 @@ def _separable_conv_wavedecn( def _separable_conv_waverecn( - coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: pywt.Wavelet, ) -> torch.Tensor: """Separable n-dimensional wavelet synthesis transform. Args: - coeffs (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coeffs (Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]]): The output as produced by `_separable_conv_wavedecn`. wavelet (pywt.Wavelet): The wavelet used by `_separable_conv_wavedecn`. @@ -315,7 +316,7 @@ def fswavedec3( def fswaverec2( - coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: Union[str, pywt.Wavelet], axes: Tuple[int, int] = (-2, -1), ) -> torch.Tensor: @@ -325,7 +326,7 @@ def fswaverec2( the hood. 
Args: - coeffs (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coeffs (Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]]): The wavelet coefficients as computed by `fswavedec2`. wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. @@ -382,14 +383,14 @@ def fswaverec2( def fswaverec3( - coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: Union[str, pywt.Wavelet], axes: Tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Compute a fully separable 3D-padded synthesis wavelet transform. Args: - coeffs (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coeffs (Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]]): The wavelet coefficients as computed by `fswavedec3`. wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. From aacca4b58fb1f2231e85ef0fce975e7146137700 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 12:35:01 +0200 Subject: [PATCH 02/40] Use builtin list instead of List --- src/ptwt/_stationary_transform.py | 6 ++--- src/ptwt/_util.py | 14 ++++++------ src/ptwt/conv_transform.py | 16 ++++++------- src/ptwt/conv_transform_2.py | 12 +++++----- src/ptwt/conv_transform_3.py | 12 +++++----- src/ptwt/matmul_transform.py | 14 ++++++------ src/ptwt/matmul_transform_2.py | 16 ++++++------- src/ptwt/matmul_transform_3.py | 12 +++++----- src/ptwt/packets.py | 34 ++++++++++++++-------------- src/ptwt/separable_conv_transform.py | 16 ++++++------- src/ptwt/sparse_math.py | 3 +-- 11 files changed, 77 insertions(+), 78 deletions(-) diff --git a/src/ptwt/_stationary_transform.py b/src/ptwt/_stationary_transform.py index d538a01b..e188dd52 100644 --- a/src/ptwt/_stationary_transform.py +++ b/src/ptwt/_stationary_transform.py @@ -1,7 +1,7 @@ """This module implements stationary wavelet transforms.""" from collections.abc import Sequence -from typing import List, Optional, Union +from typing import Optional, Union import pywt import torch @@ -20,7 +20,7 @@ def _swt( wavelet: Union[Wavelet, str], level: Optional[int] = None, axis: Optional[int] = -1, -) -> List[torch.Tensor]: +) -> list[torch.Tensor]: """Compute a multilevel 1d stationary wavelet transform. Args: @@ -29,7 +29,7 @@ def _swt( level (Optional[int], optional): The number of levels to compute Returns: - List[torch.Tensor]: Same as wavedec. + list[torch.Tensor]: Same as wavedec. Equivalent to pywt.swt with trim_approx=True. Raises: diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 73214f62..87837ced 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -2,7 +2,7 @@ from collections.abc import Sequence import typing -from typing import Any, Callable, List, Optional, Protocol, Tuple, Union +from typing import Any, Callable, Optional, Protocol, Tuple, Union import numpy as np import pywt @@ -107,7 +107,7 @@ def _pad_symmetric( return signal -def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int]]: +def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, list[int]]: """Fold unchanged leading dimensions into a single batch dimension. Args: @@ -115,7 +115,7 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int keep_no (int): The number of dimensions to keep. Returns: - Tuple[ torch.Tensor, List[int]]: + Tuple[torch.Tensor, list[int]]: The folded result array, and the shape of the original input. 
""" dshape = list(data.shape) @@ -125,7 +125,7 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int ) -def _unfold_axes(data: torch.Tensor, ds: List[int], keep_no: int) -> torch.Tensor: +def _unfold_axes(data: torch.Tensor, ds: list[int], keep_no: int) -> torch.Tensor: """Unfold i.e. [batch*channel,height,widht] to [batch,channel,height,width].""" return torch.reshape(data, ds[:-keep_no] + list(data.shape[-keep_no:])) @@ -145,7 +145,7 @@ def _check_axes_argument(axes: Sequence[int]) -> None: def _get_transpose_order( axes: Sequence[int], data_shape: Sequence[int] -) -> Tuple[List[int], List[int]]: +) -> Tuple[list[int], list[int]]: axes = list(map(lambda a: a + len(data_shape) if a < 0 else a, axes)) all_axes = list(range(len(data_shape))) remove_transformed = list(filter(lambda a: a not in axes, all_axes)) @@ -168,9 +168,9 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: def _map_result( data: Sequence[Union[torch.Tensor, Any]], # following jax tree_map typing can be Any function: Callable[[Any], torch.Tensor], -) -> List[Union[torch.Tensor, Any]]: +) -> list[Union[torch.Tensor, Any]]: # Apply the given function to the input list of tensor and tuples. - result_lst: List[Union[torch.Tensor, Any]] = [] + result_lst: list[Union[torch.Tensor, Any]] = [] for element in data: if isinstance(element, torch.Tensor): result_lst.append(function(element)) diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 24500532..43d9308e 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -3,7 +3,7 @@ This module treats boundaries with edge-padding. """ -from typing import List, Optional, Sequence, Tuple, Union, cast +from typing import Optional, Sequence, Tuple, Union, cast import pywt import torch @@ -174,7 +174,7 @@ def _flatten_2d_coeff_lst( Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ], flatten_tensors: bool = True, -) -> List[torch.Tensor]: +) -> list[torch.Tensor]: """Flattens a sequence of tensor tuples into a single list. Args: @@ -218,14 +218,14 @@ def _adjust_padding_at_reconstruction( def _preprocess_tensor_dec1d( data: torch.Tensor, -) -> Tuple[torch.Tensor, Union[List[int], None]]: +) -> Tuple[torch.Tensor, Union[list[int], None]]: """Preprocess input tensor dimensions. Args: data (torch.Tensor): An input tensor of any shape. Returns: - Tuple[torch.Tensor, Union[List[int], None]]: + Tuple[torch.Tensor, Union[list[int], None]]: A data tensor of shape [new_batch, 1, to_process] and the original shape, if the shape has changed. """ @@ -243,15 +243,15 @@ def _preprocess_tensor_dec1d( def _postprocess_result_list_dec1d( - result_lst: Sequence[torch.Tensor], ds: List[int] -) -> List[torch.Tensor]: + result_lst: Sequence[torch.Tensor], ds: list[int] +) -> list[torch.Tensor]: # Unfold axes for the wavelets return [_unfold_axes(fres, ds, 1) for fres in result_lst] def _preprocess_result_list_rec1d( result_lst: Sequence[torch.Tensor], -) -> Tuple[List[torch.Tensor], List[int]]: +) -> Tuple[list[torch.Tensor], list[int]]: # Fold axes for the wavelets ds = list(result_lst[0].shape) fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst] @@ -265,7 +265,7 @@ def wavedec( mode: BoundaryMode = "reflect", level: Optional[int] = None, axis: int = -1, -) -> List[torch.Tensor]: +) -> list[torch.Tensor]: r"""Compute the analysis (forward) 1d fast wavelet transform. 
The transformation relies on convolution operations with filter diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 496d38bc..f66aed06 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -6,7 +6,7 @@ from collections.abc import Sequence from functools import partial -from typing import List, Optional, Tuple, Union, cast +from typing import Optional, Tuple, Union, cast import pywt import torch @@ -99,8 +99,8 @@ def _fwt_pad2( def _waverec2d_fold_channels_2d_list( coeffs: Sequence[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], ) -> Tuple[ - List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], - List[int], + list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + list[int], ]: # fold the input coefficients for processing conv2d_transpose. ds = list(_check_if_tensor(coeffs[0]).shape) @@ -109,7 +109,7 @@ def _waverec2d_fold_channels_2d_list( def _preprocess_tensor_dec2d( data: torch.Tensor, -) -> Tuple[torch.Tensor, Union[List[int], None]]: +) -> Tuple[torch.Tensor, Union[list[int], None]]: # Preprocess multidimensional input. ds = None if len(data.shape) == 2: @@ -132,7 +132,7 @@ def wavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: Tuple[int, int] = (-2, -1), -) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: +) -> list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: r"""Run a two-dimensional wavelet transformation. This function relies on two-dimensional convolutions. @@ -215,7 +215,7 @@ def wavedec2( if level is None: level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet) - result_lst: List[ + result_lst: list[ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ] = [] res_ll = data diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 12687d23..b94c2277 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -4,7 +4,7 @@ """ from functools import partial -from typing import Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Dict, Optional, Sequence, Tuple, Union, cast import pywt import torch @@ -108,7 +108,7 @@ def wavedec3( mode: BoundaryMode = "zero", level: Optional[int] = None, axes: Tuple[int, int, int] = (-3, -2, -1), -) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: +) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a three-dimensional wavelet transform. Args: @@ -174,7 +174,7 @@ def wavedec3( [data.shape[-1], data.shape[-2], data.shape[-3]], wavelet ) - result_lst: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] + result_lst: list[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] res_lll = data for _ in range(level): if len(res_lll.shape) == 4: @@ -212,11 +212,11 @@ def wavedec3( def _waverec3d_fold_channels_3d_list( coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], ) -> Tuple[ - List[Union[torch.Tensor, Dict[str, torch.Tensor]]], - List[int], + list[Union[torch.Tensor, Dict[str, torch.Tensor]]], + list[int], ]: # fold the input coefficients for processing conv2d_transpose. 
- fold_coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] + fold_coeffs: list[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] ds = list(_check_if_tensor(coeffs[0]).shape) for coeff in coeffs: if isinstance(coeff, torch.Tensor): diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 9204d674..6bf12f33 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -9,7 +9,7 @@ import sys from collections import Sequence -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np import torch @@ -219,10 +219,10 @@ def __init__( raise ValueError("MatrixWavedec transforms a single axis only.") self.input_length: Optional[int] = None - self.fwt_matrix_list: List[torch.Tensor] = [] - self.pad_list: List[bool] = [] + self.fwt_matrix_list: list[torch.Tensor] = [] + self.pad_list: list[bool] = [] self.padded = False - self.size_list: List[int] = [] + self.size_list: list[int] = [] if not _is_boundary_mode_supported(self.boundary): raise NotImplementedError @@ -316,7 +316,7 @@ def _construct_analysis_matrices( self.size_list.append(curr_length) - def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]: + def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: """Compute the matrix fwt for the given input signal. Matrix FWTs are used to avoid padding. @@ -330,7 +330,7 @@ def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]: another axis. Returns: - List[torch.Tensor]: A list with the coefficients for each scale. + list[torch.Tensor]: A list with the coefficients for each scale. Raises: ValueError: If the decomposition level is not a positive integer @@ -501,7 +501,7 @@ def __init__( else: raise ValueError("MatrixWaverec transforms a single axis only.") - self.ifwt_matrix_list: List[torch.Tensor] = [] + self.ifwt_matrix_list: list[torch.Tensor] = [] self.level: Optional[int] = None self.input_length: Optional[int] = None self.padded = False diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 2c86e13d..a4a65050 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -6,7 +6,7 @@ import sys from collections.abc import Sequence from functools import partial -from typing import List, Optional, Tuple, Union, cast +from typing import Optional, Tuple, Union, cast import numpy as np import torch @@ -298,10 +298,10 @@ def __init__( self.boundary = boundary self.separable = separable self.input_signal_shape: Optional[Tuple[int, int]] = None - self.fwt_matrix_list: List[ + self.fwt_matrix_list: list[ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ] = [] - self.pad_list: List[Tuple[bool, bool]] = [] + self.pad_list: list[Tuple[bool, bool]] = [] self.padded = False if not _is_boundary_mode_supported(self.boundary): @@ -330,7 +330,7 @@ def sparse_fwt_operator(self) -> torch.Tensor: raise NotImplementedError # in the non-separable case the list entries are tensors - fwt_matrix_list = cast(List[torch.Tensor], self.fwt_matrix_list) + fwt_matrix_list = cast(list[torch.Tensor], self.fwt_matrix_list) if len(fwt_matrix_list) == 1: return fwt_matrix_list[0] @@ -417,7 +417,7 @@ def _construct_analysis_matrices( def __call__( self, input_signal: torch.Tensor - ) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: + ) -> list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: """Compute the fwt for the given input signal. 
The fwt matrix is set up during the first call @@ -476,7 +476,7 @@ def __call__( device=input_signal.device, dtype=input_signal.dtype ) - split_list: List[ + split_list: list[ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ] = [] if self.separable: @@ -613,7 +613,7 @@ def __init__( _check_axes_argument(list(axes)) self.axes = axes - self.ifwt_matrix_list: List[ + self.ifwt_matrix_list: list[ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ] = [] self.level: Optional[int] = None @@ -644,7 +644,7 @@ def sparse_ifwt_operator(self) -> torch.Tensor: raise NotImplementedError # in the non-separable case the list entries are tensors - ifwt_matrix_list = cast(List[torch.Tensor], self.ifwt_matrix_list) + ifwt_matrix_list = cast(list[torch.Tensor], self.ifwt_matrix_list) if len(ifwt_matrix_list) == 1: return ifwt_matrix_list[0] diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 1bf10058..5f1d9de8 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -3,7 +3,7 @@ import sys from collections.abc import Sequence from functools import partial -from typing import Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Dict, NamedTuple, Optional, Tuple, Union import numpy as np import torch @@ -88,7 +88,7 @@ def __init__( _check_axes_argument(list(axes)) self.axes = axes self.input_signal_shape: Optional[Tuple[int, int, int]] = None - self.fwt_matrix_list: List[List[torch.Tensor]] = [] + self.fwt_matrix_list: list[list[torch.Tensor]] = [] if not _is_boundary_mode_supported(self.boundary): raise NotImplementedError @@ -158,7 +158,7 @@ def _construct_analysis_matrices( def __call__( self, input_signal: torch.Tensor - ) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: + ) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a separable 3d-boundary wavelet transform. Args: @@ -169,7 +169,7 @@ def __call__( ValueError: If the input dimensions don't work. Returns: - List[Union[torch.Tensor, TypedDict[str, torch.Tensor]]]: + list[Union[torch.Tensor, TypedDict[str, torch.Tensor]]]: A list with the approximation coefficients, and a coefficient dict for each scale. """ @@ -219,7 +219,7 @@ def __call__( device=input_signal.device, dtype=input_signal.dtype ) - split_list: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] + split_list: list[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] lll = input_signal for scale, fwt_mats in enumerate(self.fwt_matrix_list): # fwt_depth_matrix, fwt_row_matrix, fwt_col_matrix = fwt_mats @@ -305,7 +305,7 @@ def __init__( _check_axes_argument(list(axes)) self.axes = axes self.boundary = boundary - self.ifwt_matrix_list: List[List[torch.Tensor]] = [] + self.ifwt_matrix_list: list[list[torch.Tensor]] = [] self.input_signal_shape: Optional[Tuple[int, int, int]] = None self.level: Optional[int] = None diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index a522741b..dbe5364a 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from functools import partial from itertools import product -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union, cast import numpy as np import pywt @@ -24,7 +24,7 @@ BaseDict = collections.UserDict -def _wpfreq(fs: float, level: int) -> List[float]: +def _wpfreq(fs: float, level: int) -> list[float]: """Compute the frequencies for a fully decomposed 1d packet tree. 
The packet transform linearly subdivides all frequencies @@ -35,7 +35,7 @@ def _wpfreq(fs: float, level: int) -> List[float]: level (int): The decomposition level. Returns: - List[float]: The frequency bins of the packets in frequency order. + list[float]: The frequency bins of the packets in frequency order. """ n = np.array(range(int(np.power(2.0, level)))) freqs = (fs / 2.0) * (n / (np.power(2.0, level))) @@ -168,7 +168,7 @@ def reconstruct(self) -> "WaveletPacket": def _get_wavedec( self, length: int, - ) -> Callable[[torch.Tensor], List[torch.Tensor]]: + ) -> Callable[[torch.Tensor], list[torch.Tensor]]: if self.mode == "boundary": if length not in self._matrix_wavedec_dict.keys(): self._matrix_wavedec_dict[length] = MatrixWavedec( @@ -183,7 +183,7 @@ def _get_wavedec( def _get_waverec( self, length: int, - ) -> Callable[[List[torch.Tensor]], torch.Tensor]: + ) -> Callable[[list[torch.Tensor]], torch.Tensor]: if self.mode == "boundary": if length not in self._matrix_waverec_dict.keys(): self._matrix_waverec_dict[length] = MatrixWaverec( @@ -193,7 +193,7 @@ def _get_waverec( else: return partial(waverec, wavelet=self.wavelet, axis=self.axis) - def get_level(self, level: int) -> List[str]: + def get_level(self, level: int) -> list[str]: """Return the graycode-ordered paths to the filter tree nodes. Args: @@ -204,7 +204,7 @@ def get_level(self, level: int) -> List[str]: """ return self._get_graycode_order(level) - def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> List[str]: + def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[str]: graycode_order = [x, y] for _ in range(level - 1): graycode_order = [x + path for path in graycode_order] + [ @@ -372,7 +372,7 @@ def reconstruct(self) -> "WaveletPacket2D": self[node] = rec return self - def get_natural_order(self, level: int) -> List[str]: + def get_natural_order(self, level: int) -> list[str]: """Get the natural ordering for a given decomposition level. 
Args: @@ -385,7 +385,7 @@ def get_natural_order(self, level: int) -> List[str]: def _get_wavedec(self, shape: Tuple[int, ...]) -> Callable[ [torch.Tensor], - List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], ]: if self.mode == "boundary": shape = tuple(shape) @@ -415,7 +415,7 @@ def _get_wavedec(self, shape: Tuple[int, ...]) -> Callable[ ) def _get_waverec(self, shape: Tuple[int, ...]) -> Callable[ - [List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], + [list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], torch.Tensor, ]: if self.mode == "boundary": @@ -438,15 +438,15 @@ def _get_waverec(self, shape: Tuple[int, ...]) -> Callable[ def _transform_fsdict_to_tuple_func( self, fs_dict_func: Callable[ - [torch.Tensor], List[Union[torch.Tensor, Dict[str, torch.Tensor]]] + [torch.Tensor], list[Union[torch.Tensor, Dict[str, torch.Tensor]]] ], ) -> Callable[ [torch.Tensor], - List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], ]: def _tuple_func( data: torch.Tensor, - ) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: + ) -> list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: a_coeff, fsdict = fs_dict_func(data) fsdict = cast(Dict[str, torch.Tensor], fsdict) return [ @@ -459,10 +459,10 @@ def _tuple_func( def _transform_tuple_to_fsdict_func( self, fsdict_func: Callable[ - [List[Union[torch.Tensor, Dict[str, torch.Tensor]]]], torch.Tensor + [list[Union[torch.Tensor, Dict[str, torch.Tensor]]]], torch.Tensor ], ) -> Callable[ - [List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], + [list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], torch.Tensor, ]: def _fsdict_func( @@ -520,7 +520,7 @@ def __getitem__(self, key: str) -> torch.Tensor: return super().__getitem__(key) -def get_freq_order(level: int) -> List[List[Tuple[str, ...]]]: +def get_freq_order(level: int) -> list[list[Tuple[str, ...]]]: """Get the frequency order for a given packet decomposition level. Use this code to create two-dimensional frequency orderings. @@ -544,7 +544,7 @@ def get_freq_order(level: int) -> List[List[Tuple[str, ...]]]: """ wp_natural_path = product(["a", "h", "v", "d"], repeat=level) - def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> List[str]: + def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]: graycode_order = [x, y] for _ in range(level - 1): graycode_order = [x + path for path in graycode_order] + [ diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 9be3589d..12ad9de0 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -9,7 +9,7 @@ from collections.abc import Sequence from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy as np import pywt @@ -112,7 +112,7 @@ def _separable_conv_wavedecn( *, mode: BoundaryMode = "reflect", level: Optional[int] = None, -) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: +) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a multilevel separable padded wavelet analysis transform. Args: @@ -122,9 +122,9 @@ def _separable_conv_wavedecn( level (int): The desired decomposition level. 
Returns: - List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: The wavelet coeffs. + list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: The wavelet coeffs. """ - result: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] + result: list[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] approx = input if level is None: @@ -181,7 +181,7 @@ def fswavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: Tuple[int, int] = (-2, -1), -) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: +) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a fully separable 2D-padded analysis wavelet transform. Args: @@ -202,7 +202,7 @@ def fswavedec2( ValueError: If the data is not a batched 2D signal. Returns: - List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: + list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: A list with the lll coefficients and dictionaries with the filter order strings:: @@ -251,7 +251,7 @@ def fswavedec3( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: Tuple[int, int, int] = (-3, -2, -1), -) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: +) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a fully separable 3D-padded analysis wavelet transform. Args: @@ -271,7 +271,7 @@ def fswavedec3( ValueError: If the input is not a batched 3D signal. Returns: - List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: + list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: A list with the lll coefficients and dictionaries with the filter order strings:: diff --git a/src/ptwt/sparse_math.py b/src/ptwt/sparse_math.py index d9cf9e6a..f7721cb6 100644 --- a/src/ptwt/sparse_math.py +++ b/src/ptwt/sparse_math.py @@ -1,7 +1,6 @@ """Efficiently construct fwt operations using sparse matrices.""" from itertools import product -from typing import List import torch @@ -308,7 +307,7 @@ def _orth_by_gram_schmidt( Returns: torch.Tensor: The orthogonalized sparse matrix. 
""" - done: List[int] = [] + done: list[int] = [] # loop over the rows we want to orthogonalize for row_no_to_ortho in to_orthogonalize: current_row = matrix.select(0, row_no_to_ortho).unsqueeze(0) From fa147976f7a7a64a034c1a29d2143e630b4825ce Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 12:42:53 +0200 Subject: [PATCH 03/40] Use builtin tuple instead of Tuple --- src/ptwt/_util.py | 16 +++++++-------- src/ptwt/continuous_transform.py | 8 ++++---- src/ptwt/conv_transform.py | 16 +++++++-------- src/ptwt/conv_transform_2.py | 24 +++++++++++----------- src/ptwt/conv_transform_3.py | 12 +++++------ src/ptwt/matmul_transform_2.py | 26 ++++++++++++------------ src/ptwt/matmul_transform_3.py | 14 ++++++------- src/ptwt/packets.py | 30 ++++++++++++++-------------- src/ptwt/separable_conv_transform.py | 16 +++++++-------- src/ptwt/wavelets_learnable.py | 11 +++++----- 10 files changed, 86 insertions(+), 87 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 87837ced..e3cb61f1 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -2,7 +2,7 @@ from collections.abc import Sequence import typing -from typing import Any, Callable, Optional, Protocol, Tuple, Union +from typing import Any, Callable, Optional, Protocol, Union import numpy as np import pywt @@ -21,7 +21,7 @@ class Wavelet(Protocol): rec_hi: Sequence[float] dec_len: int rec_len: int - filter_bank: Tuple[ + filter_bank: tuple[ Sequence[float], Sequence[float], Sequence[float], Sequence[float] ] @@ -64,7 +64,7 @@ def _outer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a_mul * b_mul -def _get_len(wavelet: Union[Tuple[torch.Tensor, ...], str, Wavelet]) -> int: +def _get_len(wavelet: Union[tuple[torch.Tensor, ...], str, Wavelet]) -> int: """Get number of filter coefficients for various wavelet data types.""" if isinstance(wavelet, tuple): return wavelet[0].shape[0] @@ -72,7 +72,7 @@ def _get_len(wavelet: Union[Tuple[torch.Tensor, ...], str, Wavelet]) -> int: return len(_as_wavelet(wavelet)) -def _pad_symmetric_1d(signal: torch.Tensor, pad_list: Tuple[int, int]) -> torch.Tensor: +def _pad_symmetric_1d(signal: torch.Tensor, pad_list: tuple[int, int]) -> torch.Tensor: padl, padr = pad_list dimlen = signal.shape[0] if padl > dimlen or padr > dimlen: @@ -93,7 +93,7 @@ def _pad_symmetric_1d(signal: torch.Tensor, pad_list: Tuple[int, int]) -> torch. def _pad_symmetric( - signal: torch.Tensor, pad_lists: Sequence[Tuple[int, int]] + signal: torch.Tensor, pad_lists: Sequence[tuple[int, int]] ) -> torch.Tensor: if len(signal.shape) < len(pad_lists): raise ValueError("not enough dimensions to pad.") @@ -107,7 +107,7 @@ def _pad_symmetric( return signal -def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, list[int]]: +def _fold_axes(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int]]: """Fold unchanged leading dimensions into a single batch dimension. Args: @@ -115,7 +115,7 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, list[int keep_no (int): The number of dimensions to keep. Returns: - Tuple[torch.Tensor, list[int]]: + tuple[torch.Tensor, list[int]]: The folded result array, and the shape of the original input. 
""" dshape = list(data.shape) @@ -145,7 +145,7 @@ def _check_axes_argument(axes: Sequence[int]) -> None: def _get_transpose_order( axes: Sequence[int], data_shape: Sequence[int] -) -> Tuple[list[int], list[int]]: +) -> tuple[list[int], list[int]]: axes = list(map(lambda a: a + len(data_shape) if a < 0 else a, axes)) all_axes = list(range(len(data_shape))) remove_transformed = list(filter(lambda a: a not in axes, all_axes)) diff --git a/src/ptwt/continuous_transform.py b/src/ptwt/continuous_transform.py index deb863f1..e2fb8f89 100644 --- a/src/ptwt/continuous_transform.py +++ b/src/ptwt/continuous_transform.py @@ -3,7 +3,7 @@ This module is based on pywt's cwt implementation. """ -from typing import Any, Tuple, Union +from typing import Any, Union import numpy as np import torch @@ -27,7 +27,7 @@ def cwt( scales: Union[np.ndarray, torch.Tensor], # type: ignore wavelet: Union[ContinuousWavelet, str], sampling_period: float = 1.0, -) -> Tuple[torch.Tensor, np.ndarray]: # type: ignore +) -> tuple[torch.Tensor, np.ndarray]: # type: ignore """Compute the single-dimensional continuous wavelet transform. This function is a PyTorch port of pywt.cwt as found at: @@ -50,7 +50,7 @@ def cwt( ValueError: If a scale is too small for the input signal. Returns: - Tuple[torch.Tensor, np.ndarray]: The first tuple-element contains + tuple[torch.Tensor, np.ndarray]: The first tuple-element contains the transformation matrix of shape [scales, batch, time]. The second element contains an array with frequency information. @@ -267,7 +267,7 @@ def center(self) -> torch.Tensor: def wavefun( self, precision: int, dtype: torch.dtype = torch.float64 - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """Define a grid and evaluate the wavelet on it.""" length = 2**precision # load the bounds from untyped pywt code. diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 43d9308e..838bc039 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -3,7 +3,7 @@ This module treats boundaries with edge-padding. """ -from typing import Optional, Sequence, Tuple, Union, cast +from typing import Optional, Sequence, Union, cast import pywt import torch @@ -40,7 +40,7 @@ def _get_filter_tensors( flip: bool, device: Union[torch.device, str], dtype: torch.dtype = torch.float32, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Convert input wavelet to filter tensors. Args: @@ -70,7 +70,7 @@ def _get_filter_tensors( return dec_lo_tensor, dec_hi_tensor, rec_lo_tensor, rec_hi_tensor -def _get_pad(data_len: int, filt_len: int) -> Tuple[int, int]: +def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]: """Compute the required padding. 
Args: @@ -171,7 +171,7 @@ def _fwt_pad( def _flatten_2d_coeff_lst( coeff_lst_2d: Sequence[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ], flatten_tensors: bool = True, ) -> list[torch.Tensor]: @@ -202,7 +202,7 @@ def _flatten_2d_coeff_lst( def _adjust_padding_at_reconstruction( res_ll_size: int, coeff_size: int, pad_end: int, pad_start: int -) -> Tuple[int, int]: +) -> tuple[int, int]: pred_size = res_ll_size - (pad_start + pad_end) next_size = coeff_size if next_size == pred_size: @@ -218,14 +218,14 @@ def _adjust_padding_at_reconstruction( def _preprocess_tensor_dec1d( data: torch.Tensor, -) -> Tuple[torch.Tensor, Union[list[int], None]]: +) -> tuple[torch.Tensor, Union[list[int], None]]: """Preprocess input tensor dimensions. Args: data (torch.Tensor): An input tensor of any shape. Returns: - Tuple[torch.Tensor, Union[list[int], None]]: + tuple[torch.Tensor, Union[list[int], None]]: A data tensor of shape [new_batch, 1, to_process] and the original shape, if the shape has changed. """ @@ -251,7 +251,7 @@ def _postprocess_result_list_dec1d( def _preprocess_result_list_rec1d( result_lst: Sequence[torch.Tensor], -) -> Tuple[list[torch.Tensor], list[int]]: +) -> tuple[list[torch.Tensor], list[int]]: # Fold axes for the wavelets ds = list(result_lst[0].shape) fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst] diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index f66aed06..40d617a7 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -6,7 +6,7 @@ from collections.abc import Sequence from functools import partial -from typing import Optional, Tuple, Union, cast +from typing import Optional, Union, cast import pywt import torch @@ -97,9 +97,9 @@ def _fwt_pad2( def _waverec2d_fold_channels_2d_list( - coeffs: Sequence[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], -) -> Tuple[ - list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], +) -> tuple[ + list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], list[int], ]: # fold the input coefficients for processing conv2d_transpose. @@ -109,7 +109,7 @@ def _waverec2d_fold_channels_2d_list( def _preprocess_tensor_dec2d( data: torch.Tensor, -) -> Tuple[torch.Tensor, Union[list[int], None]]: +) -> tuple[torch.Tensor, Union[list[int], None]]: # Preprocess multidimensional input. ds = None if len(data.shape) == 2: @@ -131,8 +131,8 @@ def wavedec2( *, mode: BoundaryMode = "reflect", level: Optional[int] = None, - axes: Tuple[int, int] = (-2, -1), -) -> list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: + axes: tuple[int, int] = (-2, -1), +) -> list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: r"""Run a two-dimensional wavelet transformation. This function relies on two-dimensional convolutions. @@ -167,7 +167,7 @@ def wavedec2( Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`. level (int): The number of desired scales. Defaults to None. - axes (Tuple[int, int]): Compute the transform over these axes instead of the + axes (tuple[int, int]): Compute the transform over these axes instead of the last two. Defaults to (-2, -1). 
Returns: @@ -216,7 +216,7 @@ def wavedec2( level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet) result_lst: list[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ] = [] res_ll = data for _ in range(level): @@ -240,9 +240,9 @@ def wavedec2( def waverec2( - coeffs: Sequence[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], wavelet: Union[Wavelet, str], - axes: Tuple[int, int] = (-2, -1), + axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. @@ -259,7 +259,7 @@ def waverec2( and D diagonal coefficients. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. - axes (Tuple[int, int]): Compute the transform over these axes instead of the + axes (tuple[int, int]): Compute the transform over these axes instead of the last two. Defaults to (-2, -1). Returns: diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index b94c2277..0be6bb63 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -4,7 +4,7 @@ """ from functools import partial -from typing import Dict, Optional, Sequence, Tuple, Union, cast +from typing import Dict, Optional, Sequence, Union, cast import pywt import torch @@ -107,7 +107,7 @@ def wavedec3( *, mode: BoundaryMode = "zero", level: Optional[int] = None, - axes: Tuple[int, int, int] = (-3, -2, -1), + axes: tuple[int, int, int] = (-3, -2, -1), ) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a three-dimensional wavelet transform. @@ -121,7 +121,7 @@ def wavedec3( Defaults to "zero". See :data:`ptwt.constants.BoundaryMode`. level (Optional[int]): The maximum decomposition level. This argument defaults to None. - axes (Tuple[int, int, int]): Compute the transform over these axes + axes (tuple[int, int, int]): Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1). Returns: @@ -211,7 +211,7 @@ def wavedec3( def _waverec3d_fold_channels_3d_list( coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], -) -> Tuple[ +) -> tuple[ list[Union[torch.Tensor, Dict[str, torch.Tensor]]], list[int], ]: @@ -232,7 +232,7 @@ def _waverec3d_fold_channels_3d_list( def waverec3( coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: Union[Wavelet, str], - axes: Tuple[int, int, int] = (-3, -2, -1), + axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. @@ -240,7 +240,7 @@ def waverec3( coeffs (sequence): The wavelet coefficient sequence produced by wavedec3. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. - axes (Tuple[int, int, int]): Transform these axes instead of the + axes (tuple[int, int, int]): Transform these axes instead of the last three. Defaults to (-3, -2, -1). 
Returns: diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index a4a65050..c2060665 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -6,7 +6,7 @@ import sys from collections.abc import Sequence from functools import partial -from typing import Optional, Tuple, Union, cast +from typing import Optional, Union, cast import numpy as np import torch @@ -211,7 +211,7 @@ def construct_boundary_s2( return orth_s -def _matrix_pad_2(height: int, width: int) -> Tuple[int, int, Tuple[bool, bool]]: +def _matrix_pad_2(height: int, width: int) -> tuple[int, int, tuple[bool, bool]]: pad_tuple = (False, False) if height % 2 != 0: height += 1 @@ -257,7 +257,7 @@ def __init__( self, wavelet: Union[Wavelet, str], level: Optional[int] = None, - axes: Tuple[int, int] = (-2, -1), + axes: tuple[int, int] = (-2, -1), boundary: OrthogonalizeMethod = "qr", separable: bool = True, ): @@ -297,11 +297,11 @@ def __init__( self.level = level self.boundary = boundary self.separable = separable - self.input_signal_shape: Optional[Tuple[int, int]] = None + self.input_signal_shape: Optional[tuple[int, int]] = None self.fwt_matrix_list: list[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] ] = [] - self.pad_list: list[Tuple[bool, bool]] = [] + self.pad_list: list[tuple[bool, bool]] = [] self.padded = False if not _is_boundary_mode_supported(self.boundary): @@ -417,7 +417,7 @@ def _construct_analysis_matrices( def __call__( self, input_signal: torch.Tensor - ) -> list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: + ) -> list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: """Compute the fwt for the given input signal. The fwt matrix is set up during the first call @@ -477,7 +477,7 @@ def __call__( ) split_list: list[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ] = [] if self.separable: ll = input_signal @@ -531,7 +531,7 @@ def __call__( coefficients, int(np.prod((size[0] // 2, size[1] // 2))) ) reshaped = cast( - Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple( ( el.T.reshape( @@ -576,7 +576,7 @@ class MatrixWaverec2(object): def __init__( self, wavelet: Union[Wavelet, str], - axes: Tuple[int, int] = (-2, -1), + axes: tuple[int, int] = (-2, -1), boundary: OrthogonalizeMethod = "qr", separable: bool = True, ): @@ -614,10 +614,10 @@ def __init__( self.axes = axes self.ifwt_matrix_list: list[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] ] = [] self.level: Optional[int] = None - self.input_signal_shape: Optional[Tuple[int, int]] = None + self.input_signal_shape: Optional[tuple[int, int]] = None self.padded = False @@ -728,7 +728,7 @@ def _construct_synthesis_matrices( def __call__( self, coefficients: Sequence[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ], ) -> torch.Tensor: """Compute the inverse matrix 2d fast wavelet transform. 
diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 5f1d9de8..774e17df 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -3,7 +3,7 @@ import sys from collections.abc import Sequence from functools import partial -from typing import Dict, NamedTuple, Optional, Tuple, Union +from typing import Dict, NamedTuple, Optional, Union import numpy as np import torch @@ -37,7 +37,7 @@ class _PadTuple(NamedTuple): def _matrix_pad_3( depth: int, height: int, width: int -) -> Tuple[int, int, int, _PadTuple]: +) -> tuple[int, int, int, _PadTuple]: pad_depth, pad_height, pad_width = (False, False, False) if height % 2 != 0: height += 1 @@ -58,7 +58,7 @@ def __init__( self, wavelet: Union[Wavelet, str], level: Optional[int] = None, - axes: Tuple[int, int, int] = (-3, -2, -1), + axes: tuple[int, int, int] = (-3, -2, -1), boundary: OrthogonalizeMethod = "qr", ): """Create a *separable* three-dimensional fast boundary wavelet transform. @@ -87,7 +87,7 @@ def __init__( else: _check_axes_argument(list(axes)) self.axes = axes - self.input_signal_shape: Optional[Tuple[int, int, int]] = None + self.input_signal_shape: Optional[tuple[int, int, int]] = None self.fwt_matrix_list: list[list[torch.Tensor]] = [] if not _is_boundary_mode_supported(self.boundary): @@ -278,7 +278,7 @@ class MatrixWaverec3(object): def __init__( self, wavelet: Union[Wavelet, str], - axes: Tuple[int, int, int] = (-3, -2, -1), + axes: tuple[int, int, int] = (-3, -2, -1), boundary: OrthogonalizeMethod = "qr", ): """Compute a three-dimensional separable boundary wavelet synthesis transform. @@ -286,7 +286,7 @@ def __init__( Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. - axes (Tuple[int, int, int]): Transform these axes instead of the + axes (tuple[int, int, int]): Transform these axes instead of the last three. Defaults to (-3, -2, -1). boundary : The method used for boundary filter treatment. Choose 'qr' or 'gramschmidt'. 
'qr' relies on Pytorch's dense qr @@ -306,7 +306,7 @@ def __init__( self.axes = axes self.boundary = boundary self.ifwt_matrix_list: list[list[torch.Tensor]] = [] - self.input_signal_shape: Optional[Tuple[int, int, int]] = None + self.input_signal_shape: Optional[tuple[int, int, int]] = None self.level: Optional[int] = None if not _is_boundary_mode_supported(self.boundary): diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index dbe5364a..170d4cb0 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from functools import partial from itertools import product -from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Callable, Dict, Optional, Union, cast import numpy as np import pywt @@ -266,7 +266,7 @@ def __init__( wavelet: Union[Wavelet, str], mode: ExtendedBoundaryMode = "reflect", maxlevel: Optional[int] = None, - axes: Tuple[int, int] = (-2, -1), + axes: tuple[int, int] = (-2, -1), boundary_orthogonalization: OrthogonalizeMethod = "qr", separable: bool = False, ) -> None: @@ -300,8 +300,8 @@ def __init__( self.mode = mode self.boundary = boundary_orthogonalization self.separable = separable - self.matrix_wavedec2_dict: Dict[Tuple[int, ...], MatrixWavedec2] = {} - self.matrix_waverec2_dict: Dict[Tuple[int, ...], MatrixWaverec2] = {} + self.matrix_wavedec2_dict: Dict[tuple[int, ...], MatrixWavedec2] = {} + self.matrix_waverec2_dict: Dict[tuple[int, ...], MatrixWaverec2] = {} self.axes = axes self.maxlevel: Optional[int] = None @@ -383,9 +383,9 @@ def get_natural_order(self, level: int) -> list[str]: """ return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)] - def _get_wavedec(self, shape: Tuple[int, ...]) -> Callable[ + def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ [torch.Tensor], - list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], ]: if self.mode == "boundary": shape = tuple(shape) @@ -414,8 +414,8 @@ def _get_wavedec(self, shape: Tuple[int, ...]) -> Callable[ wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes ) - def _get_waverec(self, shape: Tuple[int, ...]) -> Callable[ - [list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], + def _get_waverec(self, shape: tuple[int, ...]) -> Callable[ + [list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], torch.Tensor, ]: if self.mode == "boundary": @@ -442,11 +442,11 @@ def _transform_fsdict_to_tuple_func( ], ) -> Callable[ [torch.Tensor], - list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], ]: def _tuple_func( data: torch.Tensor, - ) -> list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: + ) -> list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: a_coeff, fsdict = fs_dict_func(data) fsdict = cast(Dict[str, torch.Tensor], fsdict) return [ @@ -462,12 +462,12 @@ def _transform_tuple_to_fsdict_func( [list[Union[torch.Tensor, Dict[str, torch.Tensor]]]], torch.Tensor ], ) -> Callable[ - [list[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], + [list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], torch.Tensor, ]: def _fsdict_func( coeffs: Sequence[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + Union[torch.Tensor, 
tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ] ) -> torch.Tensor: a, (h, v, d) = coeffs @@ -520,7 +520,7 @@ def __getitem__(self, key: str) -> torch.Tensor: return super().__getitem__(key) -def get_freq_order(level: int) -> list[list[Tuple[str, ...]]]: +def get_freq_order(level: int) -> list[list[tuple[str, ...]]]: """Get the frequency order for a given packet decomposition level. Use this code to create two-dimensional frequency orderings. @@ -552,14 +552,14 @@ def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]: ] return graycode_order - def _expand_2d_path(path: Tuple[str, ...]) -> Tuple[str, str]: + def _expand_2d_path(path: tuple[str, ...]) -> tuple[str, str]: expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"} return ( "".join([expanded_paths[p][0] for p in path]), "".join([expanded_paths[p][1] for p in path]), ) - nodes_dict: Dict[str, Dict[str, Tuple[str, ...]]] = {} + nodes_dict: Dict[str, Dict[str, tuple[str, ...]]] = {} for (row_path, col_path), node in [ (_expand_2d_path(node), node) for node in wp_natural_path ]: diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 12ad9de0..915934e8 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -9,7 +9,7 @@ from collections.abc import Sequence from functools import partial -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Union import numpy as np import pywt @@ -180,7 +180,7 @@ def fswavedec2( *, mode: BoundaryMode = "reflect", level: Optional[int] = None, - axes: Tuple[int, int] = (-2, -1), + axes: tuple[int, int] = (-2, -1), ) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a fully separable 2D-padded analysis wavelet transform. @@ -250,7 +250,7 @@ def fswavedec3( *, mode: BoundaryMode = "reflect", level: Optional[int] = None, - axes: Tuple[int, int, int] = (-3, -2, -1), + axes: tuple[int, int, int] = (-3, -2, -1), ) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a fully separable 3D-padded analysis wavelet transform. @@ -264,7 +264,7 @@ def fswavedec3( Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`. level (int): The number of desired scales. Defaults to None. - axes (Tuple[int, int, int]): Compute the transform over these axes + axes (tuple[int, int, int]): Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1). Raises: @@ -318,7 +318,7 @@ def fswavedec3( def fswaverec2( coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: Union[str, pywt.Wavelet], - axes: Tuple[int, int] = (-2, -1), + axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: """Compute a fully separable 2D-padded synthesis wavelet transform. @@ -330,7 +330,7 @@ def fswaverec2( The wavelet coefficients as computed by `fswavedec2`. wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. - axes (Tuple[int, int]): Compute the transform over these + axes (tuple[int, int]): Compute the transform over these axes instead of the last two. Defaults to (-2, -1). Returns: @@ -385,7 +385,7 @@ def fswaverec2( def fswaverec3( coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: Union[str, pywt.Wavelet], - axes: Tuple[int, int, int] = (-3, -2, -1), + axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Compute a fully separable 3D-padded synthesis wavelet transform. @@ -394,7 +394,7 @@ def fswaverec3( The wavelet coefficients as computed by `fswavedec3`. 
wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. - axes (Tuple[int, int, int]): Compute the transform over these axes + axes (tuple[int, int, int]): Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1). Returns: diff --git a/src/ptwt/wavelets_learnable.py b/src/ptwt/wavelets_learnable.py index 76c39036..3e0fa118 100644 --- a/src/ptwt/wavelets_learnable.py +++ b/src/ptwt/wavelets_learnable.py @@ -5,7 +5,6 @@ # Inspired by Ripples in Mathematics, Jensen and La Cour-Harbo, Chapter 7.7 from abc import ABC, abstractmethod -from typing import Tuple import torch @@ -22,7 +21,7 @@ class WaveletFilter(ABC): @abstractmethod def filter_bank( self, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Return dec_lo, dec_hi, rec_lo, rec_hi.""" raise NotImplementedError @@ -42,7 +41,7 @@ def __len__(self) -> int: def pf_alias_cancellation_loss( self, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Return the product filter-alias cancellation loss. See: Strang+Nguyen 105: F0(z) = H1(-z); F1(z) = -H0(-z) @@ -77,7 +76,7 @@ def pf_alias_cancellation_loss( def alias_cancellation_loss( self, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Return the alias cancellation loss. Implementation of the ac-loss as described @@ -119,7 +118,7 @@ def alias_cancellation_loss( def perfect_reconstruction_loss( self, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Return the perfect reconstruction loss. Returns: @@ -196,7 +195,7 @@ def __init__( @property def filter_bank( self, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Return all filters a a tuple.""" return self.dec_lo, self.dec_hi, self.rec_lo, self.rec_hi From 75c0bc0be21db0d4e61f4a426c0f730422d8ff07 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 12:47:50 +0200 Subject: [PATCH 04/40] Use builtin dict instead of Dict --- src/ptwt/conv_transform.py | 4 +-- src/ptwt/conv_transform_3.py | 17 +++++++------ src/ptwt/matmul_transform_3.py | 18 ++++++------- src/ptwt/packets.py | 18 ++++++------- src/ptwt/separable_conv_transform.py | 38 ++++++++++++++-------------- 5 files changed, 48 insertions(+), 47 deletions(-) diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 838bc039..66c9798e 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -2,8 +2,8 @@ This module treats boundaries with edge-padding. """ - -from typing import Optional, Sequence, Union, cast +from collections.abc import Sequence +from typing import Optional, Union, cast import pywt import torch diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 0be6bb63..51224dc5 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -3,8 +3,9 @@ The functions here are based on torch.nn.functional.conv3d and it's transpose. 
""" +from collections.abc import Sequence from functools import partial -from typing import Dict, Optional, Sequence, Union, cast +from typing import Optional, Union, cast import pywt import torch @@ -108,7 +109,7 @@ def wavedec3( mode: BoundaryMode = "zero", level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), -) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: +) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: """Compute a three-dimensional wavelet transform. Args: @@ -174,7 +175,7 @@ def wavedec3( [data.shape[-1], data.shape[-2], data.shape[-3]], wavelet ) - result_lst: list[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] + result_lst: list[Union[torch.Tensor, dict[str, torch.Tensor]]] = [] res_lll = data for _ in range(level): if len(res_lll.shape) == 4: @@ -210,13 +211,13 @@ def wavedec3( def _waverec3d_fold_channels_3d_list( - coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], ) -> tuple[ - list[Union[torch.Tensor, Dict[str, torch.Tensor]]], + list[Union[torch.Tensor, dict[str, torch.Tensor]]], list[int], ]: # fold the input coefficients for processing conv2d_transpose. - fold_coeffs: list[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] + fold_coeffs: list[Union[torch.Tensor, dict[str, torch.Tensor]]] = [] ds = list(_check_if_tensor(coeffs[0]).shape) for coeff in coeffs: if isinstance(coeff, torch.Tensor): @@ -230,7 +231,7 @@ def _waverec3d_fold_channels_3d_list( def waverec3( - coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], wavelet: Union[Wavelet, str], axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: @@ -292,7 +293,7 @@ def waverec3( filt_len = rec_lo.shape[-1] rec_filt = _construct_3d_filt(lo=rec_lo, hi=rec_hi) - coeff_dicts = cast(Sequence[Dict[str, torch.Tensor]], coeffs[1:]) + coeff_dicts = cast(Sequence[dict[str, torch.Tensor]], coeffs[1:]) for c_pos, coeff_dict in enumerate(coeff_dicts): if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7: raise ValueError( diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 774e17df..15aed691 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -3,7 +3,7 @@ import sys from collections.abc import Sequence from functools import partial -from typing import Dict, NamedTuple, Optional, Union +from typing import NamedTuple, Optional, Union import numpy as np import torch @@ -158,7 +158,7 @@ def _construct_analysis_matrices( def __call__( self, input_signal: torch.Tensor - ) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: + ) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: """Compute a separable 3d-boundary wavelet transform. Args: @@ -169,7 +169,7 @@ def __call__( ValueError: If the input dimensions don't work. Returns: - list[Union[torch.Tensor, TypedDict[str, torch.Tensor]]]: + list[Union[torch.Tensor, dict[str, torch.Tensor]]]: A list with the approximation coefficients, and a coefficient dict for each scale. 
""" @@ -219,7 +219,7 @@ def __call__( device=input_signal.device, dtype=input_signal.dtype ) - split_list: list[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] + split_list: list[Union[torch.Tensor, dict[str, torch.Tensor]]] = [] lll = input_signal for scale, fwt_mats in enumerate(self.fwt_matrix_list): # fwt_depth_matrix, fwt_row_matrix, fwt_col_matrix = fwt_mats @@ -239,7 +239,7 @@ def _split_rec( tensor: torch.Tensor, key: str, depth: int, - dict: Dict[str, torch.Tensor], + dict: dict[str, torch.Tensor], ) -> None: if key: dict[key] = tensor @@ -249,7 +249,7 @@ def _split_rec( _split_rec(ca, "a" + key, depth, dict) _split_rec(cd, "d" + key, depth, dict) - coeff_dict: Dict[str, torch.Tensor] = {} + coeff_dict: dict[str, torch.Tensor] = {} _split_rec(lll, "", 3, coeff_dict) lll = coeff_dict["aaa"] result_keys = list( @@ -370,7 +370,7 @@ def _construct_synthesis_matrices( current_width // 2, ) - def _cat_coeff_recursive(self, input_dict: Dict[str, torch.Tensor]) -> torch.Tensor: + def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Tensor: done_dict = {} a_initial_keys = list(filter(lambda x: x[0] == "a", input_dict.keys())) for a_key in a_initial_keys: @@ -387,12 +387,12 @@ def _cat_coeff_recursive(self, input_dict: Dict[str, torch.Tensor]) -> torch.Ten return self._cat_coeff_recursive(done_dict) def __call__( - self, coefficients: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]] + self, coefficients: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]] ) -> torch.Tensor: """Reconstruct a batched 3d-signal from its coefficients. Args: - coefficients (Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coefficients (Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]]): The output from MatrixWavedec3. Returns: diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 170d4cb0..054e7cb1 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from functools import partial from itertools import product -from typing import TYPE_CHECKING, Callable, Dict, Optional, Union, cast +from typing import TYPE_CHECKING, Callable, Optional, Union, cast import numpy as np import pywt @@ -97,8 +97,8 @@ def __init__( self.wavelet = _as_wavelet(wavelet) self.mode = mode self.boundary = boundary_orthogonalization - self._matrix_wavedec_dict: Dict[int, MatrixWavedec] = {} - self._matrix_waverec_dict: Dict[int, MatrixWaverec] = {} + self._matrix_wavedec_dict: dict[int, MatrixWavedec] = {} + self._matrix_waverec_dict: dict[int, MatrixWaverec] = {} self.maxlevel: Optional[int] = None self.axis = axis if data is not None: @@ -300,8 +300,8 @@ def __init__( self.mode = mode self.boundary = boundary_orthogonalization self.separable = separable - self.matrix_wavedec2_dict: Dict[tuple[int, ...], MatrixWavedec2] = {} - self.matrix_waverec2_dict: Dict[tuple[int, ...], MatrixWaverec2] = {} + self.matrix_wavedec2_dict: dict[tuple[int, ...], MatrixWavedec2] = {} + self.matrix_waverec2_dict: dict[tuple[int, ...], MatrixWaverec2] = {} self.axes = axes self.maxlevel: Optional[int] = None @@ -438,7 +438,7 @@ def _get_waverec(self, shape: tuple[int, ...]) -> Callable[ def _transform_fsdict_to_tuple_func( self, fs_dict_func: Callable[ - [torch.Tensor], list[Union[torch.Tensor, Dict[str, torch.Tensor]]] + [torch.Tensor], list[Union[torch.Tensor, dict[str, torch.Tensor]]] ], ) -> Callable[ [torch.Tensor], @@ -448,7 +448,7 @@ def _tuple_func( data: torch.Tensor, ) -> list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]]]: a_coeff, fsdict = fs_dict_func(data) - fsdict = cast(Dict[str, torch.Tensor], fsdict) + fsdict = cast(dict[str, torch.Tensor], fsdict) return [ cast(torch.Tensor, a_coeff), (fsdict["ad"], fsdict["da"], fsdict["dd"]), @@ -459,7 +459,7 @@ def _tuple_func( def _transform_tuple_to_fsdict_func( self, fsdict_func: Callable[ - [list[Union[torch.Tensor, Dict[str, torch.Tensor]]]], torch.Tensor + [list[Union[torch.Tensor, dict[str, torch.Tensor]]]], torch.Tensor ], ) -> Callable[ [list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], @@ -559,7 +559,7 @@ def _expand_2d_path(path: tuple[str, ...]) -> tuple[str, str]: "".join([expanded_paths[p][1] for p in path]), ) - nodes_dict: Dict[str, Dict[str, tuple[str, ...]]] = {} + nodes_dict: dict[str, dict[str, tuple[str, ...]]] = {} for (row_path, col_path), node in [ (_expand_2d_path(node), node) for node in wp_natural_path ]: diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 915934e8..a3b58268 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -9,7 +9,7 @@ from collections.abc import Sequence from functools import partial -from typing import Dict, Optional, Union +from typing import Optional, Union import numpy as np import pywt @@ -32,7 +32,7 @@ def _separable_conv_dwtn_( - rec_dict: Dict[str, torch.Tensor], + rec_dict: dict[str, torch.Tensor], input_arg: torch.Tensor, wavelet: Union[str, pywt.Wavelet], *, @@ -52,7 +52,7 @@ def _separable_conv_dwtn_( Defaults to "reflect". key (str): The filter application path. Defaults to "". - dict (Dict[str, torch.Tensor]): The result will be stored here + dict (dict[str, torch.Tensor]): The result will be stored here in place. Defaults to {}. """ axis_total = len(input_arg.shape) - 1 @@ -68,12 +68,12 @@ def _separable_conv_dwtn_( def _separable_conv_idwtn( - in_dict: Dict[str, torch.Tensor], wavelet: Union[str, pywt.Wavelet] + in_dict: dict[str, torch.Tensor], wavelet: Union[str, pywt.Wavelet] ) -> torch.Tensor: """Separable single level inverse fast wavelet transform. Args: - in_dict (Dict[str, torch.Tensor]): The dictionary produced + in_dict (dict[str, torch.Tensor]): The dictionary produced by _separable_conv_dwtn_ . wavelet (Union[str, pywt.Wavelet]): The wavelet used by _separable_conv_dwtn_ . @@ -112,7 +112,7 @@ def _separable_conv_wavedecn( *, mode: BoundaryMode = "reflect", level: Optional[int] = None, -) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: +) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: """Compute a multilevel separable padded wavelet analysis transform. Args: @@ -122,9 +122,9 @@ def _separable_conv_wavedecn( level (int): The desired decomposition level. Returns: - list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: The wavelet coeffs. + list[Union[torch.Tensor, dict[str, torch.Tensor]]]: The wavelet coeffs. 
""" - result: list[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] + result: list[Union[torch.Tensor, dict[str, torch.Tensor]]] = [] approx = input if level is None: @@ -134,7 +134,7 @@ def _separable_conv_wavedecn( ) for _ in range(level): - level_dict: Dict[str, torch.Tensor] = {} + level_dict: dict[str, torch.Tensor] = {} _separable_conv_dwtn_(level_dict, approx, wavelet, mode=mode, key="") approx_key = "a" * (len(input.shape) - 1) approx = level_dict.pop(approx_key) @@ -144,13 +144,13 @@ def _separable_conv_wavedecn( def _separable_conv_waverecn( - coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], wavelet: pywt.Wavelet, ) -> torch.Tensor: """Separable n-dimensional wavelet synthesis transform. Args: - coeffs (Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coeffs (Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]]): The output as produced by `_separable_conv_wavedecn`. wavelet (pywt.Wavelet): The wavelet used by `_separable_conv_wavedecn`. @@ -181,7 +181,7 @@ def fswavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int] = (-2, -1), -) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: +) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: """Compute a fully separable 2D-padded analysis wavelet transform. Args: @@ -202,7 +202,7 @@ def fswavedec2( ValueError: If the data is not a batched 2D signal. Returns: - list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: + list[Union[torch.Tensor, dict[str, torch.Tensor]]]: A list with the lll coefficients and dictionaries with the filter order strings:: @@ -251,7 +251,7 @@ def fswavedec3( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), -) -> list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: +) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: """Compute a fully separable 3D-padded analysis wavelet transform. Args: @@ -271,7 +271,7 @@ def fswavedec3( ValueError: If the input is not a batched 3D signal. Returns: - list[Union[torch.Tensor, Dict[str, torch.Tensor]]]: + list[Union[torch.Tensor, dict[str, torch.Tensor]]]: A list with the lll coefficients and dictionaries with the filter order strings:: @@ -316,7 +316,7 @@ def fswavedec3( def fswaverec2( - coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], wavelet: Union[str, pywt.Wavelet], axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: @@ -326,7 +326,7 @@ def fswaverec2( the hood. Args: - coeffs (Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coeffs (Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]]): The wavelet coefficients as computed by `fswavedec2`. wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. @@ -383,14 +383,14 @@ def fswaverec2( def fswaverec3( - coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], wavelet: Union[str, pywt.Wavelet], axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Compute a fully separable 3D-padded synthesis wavelet transform. Args: - coeffs (Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coeffs (Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]]): The wavelet coefficients as computed by `fswavedec3`. wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. 
From 69c69a2d752431c25931626faa4a3a2ac6dfe4df Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 22:27:55 +0200 Subject: [PATCH 05/40] Change return types to tuple --- src/ptwt/_util.py | 6 +++- src/ptwt/conv_transform_2.py | 38 +++++++++++------------- src/ptwt/conv_transform_3.py | 25 ++++++++-------- src/ptwt/matmul_transform_2.py | 31 +++++++++----------- src/ptwt/matmul_transform_3.py | 26 +++++++++------- src/ptwt/separable_conv_transform.py | 44 +++++++++++++++------------- 6 files changed, 87 insertions(+), 83 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index e3cb61f1..968d8758 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -2,7 +2,8 @@ from collections.abc import Sequence import typing -from typing import Any, Callable, Optional, Protocol, Union +from typing import Any, Callable, Optional, Protocol, Union, overload +from typing_extensions import Unpack import numpy as np import pywt @@ -29,6 +30,9 @@ def __len__(self) -> int: """Return the number of filter coefficients.""" return len(self.dec_lo) +WaveletTransformReturn2d = tuple[torch.Tensor, Unpack[tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]]] +WaveletTransformReturn3d = tuple[torch.Tensor, Unpack[tuple[dict[str, torch.Tensor], ...]]] + def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet: """Ensure the input argument to be a pywt wavelet compatible object. diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 40d617a7..bffa2e07 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -13,6 +13,7 @@ from ._util import ( Wavelet, + WaveletTransformReturn2d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -96,12 +97,7 @@ def _fwt_pad2( return data_pad -def _waverec2d_fold_channels_2d_list( - coeffs: Sequence[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], -) -> tuple[ - list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], - list[int], -]: +def _waverec2d_fold_channels_2d_list(coeffs: WaveletTransformReturn2d) -> tuple[WaveletTransformReturn2d, list[int]]: # fold the input coefficients for processing conv2d_transpose. ds = list(_check_if_tensor(coeffs[0]).shape) return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds @@ -132,7 +128,7 @@ def wavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int] = (-2, -1), -) -> list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: +) -> WaveletTransformReturn2d: r"""Run a two-dimensional wavelet transformation. This function relies on two-dimensional convolutions. @@ -171,13 +167,13 @@ def wavedec2( last two. Defaults to (-2, -1). Returns: - list: A list containing the wavelet coefficients. + WaveletTransformReturn2d: A tuple containing the wavelet coefficients. The coefficients are in pywt order. That is:: [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] . - A denotes approximation, H horizontal, V vertical - and D diagonal coefficients. + 'A' denotes approximation, 'H' horizontal, 'V' vertical + and 'D' diagonal coefficients. 
Raises: ValueError: If the dimensionality or the dtype of the input data tensor @@ -215,9 +211,7 @@ def wavedec2( if level is None: level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet) - result_lst: list[ - Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] = [] + result_lst: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] res_ll = data for _ in range(level): res_ll = _fwt_pad2(res_ll, wavelet, mode=mode) @@ -225,22 +219,24 @@ def wavedec2( res_ll, res_lh, res_hl, res_hh = torch.split(res, 1, 1) to_append = (res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1)) result_lst.append(to_append) - result_lst.append(res_ll.squeeze(1)) + result_lst.reverse() + res_ll = res_ll.squeeze(1) + result = res_ll, *result_lst if ds: _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - result_lst = _map_result(result_lst, _unfold_axes2) + result = _map_result(result, _unfold_axes2) if axes != (-2, -1): undo_swap_fn = partial(_undo_swap_axes, axes=axes) - result_lst = _map_result(result_lst, undo_swap_fn) + result = _map_result(result, undo_swap_fn) - return result_lst + return result def waverec2( - coeffs: Sequence[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + coeffs: WaveletTransformReturn2d, wavelet: Union[Wavelet, str], axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: @@ -250,13 +246,13 @@ def waverec2( or forward transform by running transposed convolutions. Args: - coeffs (sequence): The wavelet coefficient sequence produced by wavedec2. + coeffs (WaveletTransformReturn2d): The wavelet coefficient tupl produced by wavedec2. The coefficients must be in pywt order. That is:: [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] . - A denotes approximation, H horizontal, V vertical, - and D diagonal coefficients. + 'A' denotes approximation, 'H' horizontal, 'V' vertical, + and 'D' diagonal coefficients. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. axes (tuple[int, int]): Compute the transform over these axes instead of the diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 51224dc5..c358017f 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -12,6 +12,7 @@ from ._util import ( Wavelet, + WaveletTransformReturn3d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -109,7 +110,7 @@ def wavedec3( mode: BoundaryMode = "zero", level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), -) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: +) -> WaveletTransformReturn3d: """Compute a three-dimensional wavelet transform. Args: @@ -126,13 +127,13 @@ def wavedec3( instead of the last three. Defaults to (-3, -2, -1). Returns: - list: A list with the lll coefficients and dictionaries - with the filter order strings:: + WaveletTransformReturn3d: A tuple with the lll coefficients and + dictionaries with the filter order strings:: ("aad", "ada", "add", "daa", "dad", "dda", "ddd") - as keys. With a for the low pass or approximation filter and - d for the high-pass or detail filter. + as keys. With 'a' for the low pass or approximation filter and + 'd' for the high-pass or detail filter. 
Raises: ValueError: If the input has fewer than three dimensions or @@ -196,22 +197,22 @@ def wavedec3( "ddd": res_hhh, } ) - result_lst.append(res_lll) result_lst.reverse() + result = res_lll, *result_lst if ds: _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) - result_lst = _map_result(result_lst, _unfold_axes_fn) + result = _map_result(result, _unfold_axes_fn) if tuple(axes) != (-3, -2, -1): undo_swap_fn = partial(_undo_swap_axes, axes=axes) - result_lst = _map_result(result_lst, undo_swap_fn) + result = _map_result(result, undo_swap_fn) - return result_lst + return result def _waverec3d_fold_channels_3d_list( - coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], + coeffs: WaveletTransformReturn3d, ) -> tuple[ list[Union[torch.Tensor, dict[str, torch.Tensor]]], list[int], @@ -231,14 +232,14 @@ def _waverec3d_fold_channels_3d_list( def waverec3( - coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], + coeffs: WaveletTransformReturn3d, wavelet: Union[Wavelet, str], axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. Args: - coeffs (sequence): The wavelet coefficient sequence produced by wavedec3. + coeffs (WaveletTransformReturn3d): The wavelet coefficient tuple produced by wavedec3. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. axes (tuple[int, int, int]): Transform these axes instead of the diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index c2060665..3d7348dd 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -13,6 +13,7 @@ from ._util import ( Wavelet, + WaveletTransformReturn2d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -417,7 +418,7 @@ def _construct_analysis_matrices( def __call__( self, input_signal: torch.Tensor - ) -> list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: + ) -> WaveletTransformReturn2d: """Compute the fwt for the given input signal. The fwt matrix is set up during the first call @@ -431,8 +432,8 @@ def __call__( This transform affects the last two dimensions. Returns: - (list): The resulting coefficients per level are stored in - a pywt style list. The list is ordered as:: + (WaveletTransformReturn2d): The resulting coefficients per level are stored in + a pywt style tuple. The tuple is ordered as:: (ll, (lh, hl, hh), ...) 
@@ -476,9 +477,7 @@ def __call__( device=input_signal.device, dtype=input_signal.dtype ) - split_list: list[ - Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] = [] + split_list: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] if self.separable: ll = input_signal for scale, fwt_mats in enumerate(self.fwt_matrix_list): @@ -501,7 +500,6 @@ def __call__( hl, hh = torch.split(d_coeffs, current_width // 2, dim=-1) split_list.append((lh, hl, hh)) - split_list.append(ll) else: ll = input_signal.transpose(-2, -1).reshape([batch_size, -1]).T for scale, fwt_matrix in enumerate(self.fwt_matrix_list): @@ -543,19 +541,20 @@ def __call__( ) split_list.append(reshaped) ll = four_split[0] - split_list.append( - ll.T.reshape(batch_size, size[1] // 2, size[0] // 2).transpose(2, 1) - ) + ll = ll.T.reshape(batch_size, size[1] // 2, size[0] // 2).transpose(2, 1) + + split_list.reverse() + result = ll, *split_list if ds: _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - split_list = _map_result(split_list, _unfold_axes2) + result = _map_result(result, _unfold_axes2) if self.axes != (-2, -1): undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) - split_list = _map_result(split_list, undo_swap_fn) + result = _map_result(result, undo_swap_fn) - return split_list[::-1] + return result class MatrixWaverec2(object): @@ -727,14 +726,12 @@ def _construct_synthesis_matrices( def __call__( self, - coefficients: Sequence[ - Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ], + coefficients: WaveletTransformReturn2d, ) -> torch.Tensor: """Compute the inverse matrix 2d fast wavelet transform. Args: - coefficients (list): The coefficient list as returned + coefficients (WaveletTransformReturn2d): The coefficient tuple as returned by the `MatrixWavedec2`-Object. Returns: diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 15aed691..fb449e69 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -1,7 +1,6 @@ """Implement 3D separable boundary transforms.""" import sys -from collections.abc import Sequence from functools import partial from typing import NamedTuple, Optional, Union @@ -10,6 +9,7 @@ from ._util import ( Wavelet, + WaveletTransformReturn3d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -158,7 +158,7 @@ def _construct_analysis_matrices( def __call__( self, input_signal: torch.Tensor - ) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: + ) -> WaveletTransformReturn3d: """Compute a separable 3d-boundary wavelet transform. Args: @@ -169,8 +169,8 @@ def __call__( ValueError: If the input dimensions don't work. Returns: - list[Union[torch.Tensor, dict[str, torch.Tensor]]]: - A list with the approximation coefficients, + WaveletTransformReturn3d: + A tuple with the approximation coefficients, and a coefficient dict for each scale. 
""" if self.axes != (-3, -2, -1): @@ -259,17 +259,19 @@ def _split_rec( key: tensor for key, tensor in coeff_dict.items() if key in result_keys } split_list.append(coeff_dict) - split_list.append(lll) + + split_list.reverse() + result = lll, *split_list if ds: _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) - split_list = _map_result(split_list, _unfold_axes_fn) + result = _map_result(result, _unfold_axes_fn) if self.axes != (-3, -2, -1): undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) - split_list = _map_result(split_list, undo_swap_fn) + result = _map_result(result, undo_swap_fn) - return split_list[::-1] + return result class MatrixWaverec3(object): @@ -387,13 +389,15 @@ def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Ten return self._cat_coeff_recursive(done_dict) def __call__( - self, coefficients: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]] + self, coefficients: WaveletTransformReturn3d ) -> torch.Tensor: """Reconstruct a batched 3d-signal from its coefficients. Args: - coefficients (Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]]): - The output from MatrixWavedec3. + coefficients (WaveletTransformReturn3d): + The output from MatrixWavedec3, consisting of a tuple + of the approximation coefficients and a dict with the + detail coefficients for each scale. Returns: torch.Tensor: A reconstruction of the original signal. diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index a3b58268..a23cb8d3 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -7,7 +7,6 @@ using torch.nn.functional.conv1d and it's transpose. """ -from collections.abc import Sequence from functools import partial from typing import Optional, Union @@ -16,6 +15,8 @@ import torch from ._util import ( + WaveletTransformReturn2d, + WaveletTransformReturn3d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -112,7 +113,7 @@ def _separable_conv_wavedecn( *, mode: BoundaryMode = "reflect", level: Optional[int] = None, -) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: +) -> WaveletTransformReturn3d: """Compute a multilevel separable padded wavelet analysis transform. Args: @@ -122,9 +123,10 @@ def _separable_conv_wavedecn( level (int): The desired decomposition level. Returns: - list[Union[torch.Tensor, dict[str, torch.Tensor]]]: The wavelet coeffs. + WaveletTransformReturn3d: A tuple with the approximation coefficients, + and a coefficient dict for each scale. """ - result: list[Union[torch.Tensor, dict[str, torch.Tensor]]] = [] + result: list[dict[str, torch.Tensor]] = [] approx = input if level is None: @@ -139,18 +141,18 @@ def _separable_conv_wavedecn( approx_key = "a" * (len(input.shape) - 1) approx = level_dict.pop(approx_key) result.append(level_dict) - result.append(approx) - return result[::-1] + result.reverse() + return approx, *result def _separable_conv_waverecn( - coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], + coeffs: WaveletTransformReturn3d, wavelet: pywt.Wavelet, ) -> torch.Tensor: """Separable n-dimensional wavelet synthesis transform. Args: - coeffs (Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]]): + coeffs (WaveletTransformReturn3d): The output as produced by `_separable_conv_wavedecn`. wavelet (pywt.Wavelet): The wavelet used by `_separable_conv_wavedecn`. 
@@ -168,9 +170,9 @@ def _separable_conv_waverecn( approx: torch.Tensor = coeffs[0] for level_dict in coeffs[1:]: - keys = list(level_dict.keys()) # type: ignore - level_dict["a" * max(map(len, keys))] = approx # type: ignore - approx = _separable_conv_idwtn(level_dict, wavelet) # type: ignore + keys = list(level_dict.keys()) + level_dict["a" * max(map(len, keys))] = approx + approx = _separable_conv_idwtn(level_dict, wavelet) return approx @@ -181,7 +183,7 @@ def fswavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int] = (-2, -1), -) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: +) -> WaveletTransformReturn3d: """Compute a fully separable 2D-padded analysis wavelet transform. Args: @@ -202,8 +204,8 @@ def fswavedec2( ValueError: If the data is not a batched 2D signal. Returns: - list[Union[torch.Tensor, dict[str, torch.Tensor]]]: - A list with the lll coefficients and dictionaries + WaveletTransformReturn3d: + A tuple with the lll coefficients and dictionaries with the filter order strings:: ("ad", "da", "dd") @@ -251,7 +253,7 @@ def fswavedec3( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), -) -> list[Union[torch.Tensor, dict[str, torch.Tensor]]]: +) -> WaveletTransformReturn3d: """Compute a fully separable 3D-padded analysis wavelet transform. Args: @@ -271,8 +273,8 @@ def fswavedec3( ValueError: If the input is not a batched 3D signal. Returns: - list[Union[torch.Tensor, dict[str, torch.Tensor]]]: - A list with the lll coefficients and dictionaries + WaveletTransformReturn3d: + A tuple with the lll coefficients and dictionaries with the filter order strings:: ("aad", "ada", "add", "daa", "dad", "dda", "ddd") @@ -316,7 +318,7 @@ def fswavedec3( def fswaverec2( - coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], + coeffs: WaveletTransformReturn3d, wavelet: Union[str, pywt.Wavelet], axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: @@ -326,7 +328,7 @@ def fswaverec2( the hood. Args: - coeffs (Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]]): + coeffs (WaveletTransformReturn3d): The wavelet coefficients as computed by `fswavedec2`. wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. @@ -383,14 +385,14 @@ def fswaverec2( def fswaverec3( - coeffs: Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]], + coeffs: WaveletTransformReturn3d, wavelet: Union[str, pywt.Wavelet], axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Compute a fully separable 3D-padded synthesis wavelet transform. Args: - coeffs (Sequence[Union[torch.Tensor, dict[str, torch.Tensor]]]): + coeffs (WaveletTransformReturn3d): The wavelet coefficients as computed by `fswavedec3`. wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. 
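With the return types changed to tuples, 2D analysis results unpack directly
into the approximation tensor and the per-level detail tuples. A minimal
round-trip sketch, assuming the public re-exports ptwt.wavedec2 / ptwt.waverec2
(shapes and wavelet are illustrative only)::

    import torch
    import ptwt

    data = torch.randn(3, 32, 32)                  # hypothetical batched 2D signal
    coeffs = ptwt.wavedec2(data, "haar", level=2)  # (cA2, (cH2, cV2, cD2), (cH1, cV1, cD1))
    approx, *detail_tuples = coeffs                # tuple unpacking instead of list indexing
    ch1, cv1, cd1 = detail_tuples[-1]              # finest-scale detail coefficients
    rec = ptwt.waverec2(coeffs, "haar")            # reconstruction from the same tuple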
From 3126c89c83b74aad7beb3db7047d71cc73a11ba8 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 22:28:29 +0200 Subject: [PATCH 06/40] Add function overloads --- src/ptwt/_util.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 968d8758..9aa8d36d 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -169,10 +169,26 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: return torch.permute(data, restore_sorted) +@overload def _map_result( - data: Sequence[Union[torch.Tensor, Any]], # following jax tree_map typing can be Any - function: Callable[[Any], torch.Tensor], -) -> list[Union[torch.Tensor, Any]]: + data: WaveletTransformReturn2d, + function: Callable[[torch.Tensor], torch.Tensor], +) -> WaveletTransformReturn2d: + ... + + +@overload +def _map_result( + data: WaveletTransformReturn3d, + function: Callable[[torch.Tensor], torch.Tensor], +) -> WaveletTransformReturn3d: + ... + + +def _map_result( + data: Union[WaveletTransformReturn2d, WaveletTransformReturn3d], + function: Callable[[torch.Tensor], torch.Tensor] +) -> Union[WaveletTransformReturn2d, WaveletTransformReturn3d]: # Apply the given function to the input list of tensor and tuples. result_lst: list[Union[torch.Tensor, Any]] = [] for element in data: From 46198d0ad44d4fa8b1ab14a4f179c5172e373b4c Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 22:29:03 +0200 Subject: [PATCH 07/40] refactor _map_result --- src/ptwt/_util.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 9aa8d36d..be2fa8ad 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -190,17 +190,19 @@ def _map_result( function: Callable[[torch.Tensor], torch.Tensor] ) -> Union[WaveletTransformReturn2d, WaveletTransformReturn3d]: # Apply the given function to the input list of tensor and tuples. 
- result_lst: list[Union[torch.Tensor, Any]] = [] + result_lst = [] + return_tuple = isinstance(data, tuple) for element in data: if isinstance(element, torch.Tensor): result_lst.append(function(element)) elif isinstance(element, tuple): - result_lst.append( - (function(element[0]), function(element[1]), function(element[2])) - ) + result_lst.append(tuple(map(function, element))) elif isinstance(element, dict): - new_dict = {} - for key, value in element.items(): - new_dict[key] = function(value) + new_dict = { + key: function(value) + for key, value in element.items() + } result_lst.append(new_dict) + if return_tuple: + return tuple(result_lst) return result_lst From 168b4a07e8c514fe6dbd909335fb4b5749034e4a Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 23:18:06 +0200 Subject: [PATCH 08/40] More adaptions to the new types --- src/ptwt/conv_transform_2.py | 2 +- src/ptwt/conv_transform_3.py | 26 +++++++-------- src/ptwt/matmul_transform_2.py | 2 +- src/ptwt/matmul_transform_3.py | 4 +-- src/ptwt/packets.py | 59 +++++++++++++--------------------- 5 files changed, 39 insertions(+), 54 deletions(-) diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index bffa2e07..cf87366b 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -222,7 +222,7 @@ def wavedec2( result_lst.reverse() res_ll = res_ll.squeeze(1) - result = res_ll, *result_lst + result: WaveletTransformReturn2d = res_ll, *result_lst if ds: _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index c358017f..77ae26b4 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -176,7 +176,7 @@ def wavedec3( [data.shape[-1], data.shape[-2], data.shape[-3]], wavelet ) - result_lst: list[Union[torch.Tensor, dict[str, torch.Tensor]]] = [] + result_lst: list[dict[str, torch.Tensor]] = [] res_lll = data for _ in range(level): if len(res_lll.shape) == 4: @@ -198,7 +198,7 @@ def wavedec3( } ) result_lst.reverse() - result = res_lll, *result_lst + result: WaveletTransformReturn3d = res_lll, *result_lst if ds: _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) @@ -214,21 +214,21 @@ def wavedec3( def _waverec3d_fold_channels_3d_list( coeffs: WaveletTransformReturn3d, ) -> tuple[ - list[Union[torch.Tensor, dict[str, torch.Tensor]]], + WaveletTransformReturn3d, list[int], ]: # fold the input coefficients for processing conv2d_transpose. 
- fold_coeffs: list[Union[torch.Tensor, dict[str, torch.Tensor]]] = [] + fold_approx_coeff = _fold_axes(coeffs[0], 3)[0] + fold_coeffs: list[dict[str, torch.Tensor]] = [] ds = list(_check_if_tensor(coeffs[0]).shape) - for coeff in coeffs: - if isinstance(coeff, torch.Tensor): - fold_coeffs.append(_fold_axes(coeff, 3)[0]) - else: - new_dict = {} - for key, value in coeff.items(): - new_dict[key] = _fold_axes(value, 3)[0] - fold_coeffs.append(new_dict) - return fold_coeffs, ds + fold_coeffs = [ + { + key: _fold_axes(value, 3)[0] + for key, value in coeff.items() + } + for coeff in coeffs[1:] + ] + return (fold_approx_coeff, *fold_coeffs), ds def waverec3( diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 3d7348dd..7210798c 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -544,7 +544,7 @@ def __call__( ll = ll.T.reshape(batch_size, size[1] // 2, size[0] // 2).transpose(2, 1) split_list.reverse() - result = ll, *split_list + result: WaveletTransformReturn2d = ll, *split_list if ds: _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index fb449e69..515f341f 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -219,7 +219,7 @@ def __call__( device=input_signal.device, dtype=input_signal.dtype ) - split_list: list[Union[torch.Tensor, dict[str, torch.Tensor]]] = [] + split_list: list[dict[str, torch.Tensor]] = [] lll = input_signal for scale, fwt_mats in enumerate(self.fwt_matrix_list): # fwt_depth_matrix, fwt_row_matrix, fwt_col_matrix = fwt_mats @@ -261,7 +261,7 @@ def _split_rec( split_list.append(coeff_dict) split_list.reverse() - result = lll, *split_list + result: WaveletTransformReturn3d = lll, *split_list if ds: _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 054e7cb1..54da3358 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -10,7 +10,7 @@ import pywt import torch -from ._util import Wavelet, _as_wavelet +from ._util import Wavelet, WaveletTransformReturn2d, WaveletTransformReturn3d, _as_wavelet from .constants import ExtendedBoundaryMode, OrthogonalizeMethod from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 @@ -356,7 +356,7 @@ def reconstruct(self) -> "WaveletPacket2D": data_v = self[node + "v"] data_d = self[node + "d"] rec = self._get_waverec(data_a.shape[-2:])( - [data_a, (data_h, data_v, data_d)] + (data_a, (data_h, data_v, data_d)) ) if level > 0: if rec.shape[-1] != self[node].shape[-1]: @@ -384,8 +384,7 @@ def get_natural_order(self, level: int) -> list[str]: return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)] def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ - [torch.Tensor], - list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + [torch.Tensor], WaveletTransformReturn2d, ]: if self.mode == "boundary": shape = tuple(shape) @@ -414,10 +413,7 @@ def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes ) - def _get_waverec(self, shape: tuple[int, ...]) -> Callable[ - [list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], - torch.Tensor, - ]: + def _get_waverec(self, shape: tuple[int, ...]) -> Callable[[WaveletTransformReturn2d], torch.Tensor]: if self.mode == "boundary": shape = tuple(shape) if shape not in 
self.matrix_waverec2_dict.keys(): @@ -437,41 +433,30 @@ def _get_waverec(self, shape: tuple[int, ...]) -> Callable[ def _transform_fsdict_to_tuple_func( self, - fs_dict_func: Callable[ - [torch.Tensor], list[Union[torch.Tensor, dict[str, torch.Tensor]]] - ], - ) -> Callable[ - [torch.Tensor], - list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], - ]: + fs_dict_func: Callable[[torch.Tensor], WaveletTransformReturn3d], + ) -> Callable[[torch.Tensor], WaveletTransformReturn2d]: def _tuple_func( data: torch.Tensor, - ) -> list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: - a_coeff, fsdict = fs_dict_func(data) - fsdict = cast(dict[str, torch.Tensor], fsdict) - return [ - cast(torch.Tensor, a_coeff), - (fsdict["ad"], fsdict["da"], fsdict["dd"]), - ] + ) -> WaveletTransformReturn2d: + fs_dict_data = fs_dict_func(data) + # assert for type checking + assert len(fs_dict_data) == 2 + a_coeff, fsdict = fs_dict_data + return (a_coeff, (fsdict["ad"], fsdict["da"], fsdict["dd"])) return _tuple_func def _transform_tuple_to_fsdict_func( self, - fsdict_func: Callable[ - [list[Union[torch.Tensor, dict[str, torch.Tensor]]]], torch.Tensor - ], - ) -> Callable[ - [list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], - torch.Tensor, - ]: + fsdict_func: Callable[[WaveletTransformReturn3d], torch.Tensor], + ) -> Callable[[WaveletTransformReturn2d], torch.Tensor]: def _fsdict_func( - coeffs: Sequence[ - Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] + coeffs: WaveletTransformReturn2d ) -> torch.Tensor: + # assert for type checking + assert len(coeffs) == 2 a, (h, v, d) = coeffs - return fsdict_func([cast(torch.Tensor, a), {"ad": h, "da": v, "dd": d}]) + return fsdict_func((a, {"ad": h, "da": v, "dd": d})) return _fsdict_func @@ -481,11 +466,11 @@ def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None: self.data[path] = data if level < self.maxlevel: - result_a, (result_h, result_v, result_d) = self._get_wavedec( - data.shape[-2:] - )(data) + result = self._get_wavedec(data.shape[-2:])(data) + # assert for type checking - assert not isinstance(result_a, tuple) + assert len(result) == 2 + result_a, (result_h, result_v, result_d) = result self._recursive_dwt2d(result_a, level + 1, path + "a") self._recursive_dwt2d(result_h, level + 1, path + "h") self._recursive_dwt2d(result_v, level + 1, path + "v") From 09811b09abbe580f70f74be9061fa2ae882f5014 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 23:26:33 +0200 Subject: [PATCH 09/40] Fix _map_result --- src/ptwt/_util.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index be2fa8ad..ff5421bf 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -2,7 +2,7 @@ from collections.abc import Sequence import typing -from typing import Any, Callable, Optional, Protocol, Union, overload +from typing import Any, Callable, Optional, Protocol, Union, cast, overload from typing_extensions import Unpack import numpy as np @@ -189,20 +189,32 @@ def _map_result( data: Union[WaveletTransformReturn2d, WaveletTransformReturn3d], function: Callable[[torch.Tensor], torch.Tensor] ) -> Union[WaveletTransformReturn2d, WaveletTransformReturn3d]: - # Apply the given function to the input list of tensor and tuples. 
- result_lst = [] return_tuple = isinstance(data, tuple) - for element in data: - if isinstance(element, torch.Tensor): - result_lst.append(function(element)) - elif isinstance(element, tuple): - result_lst.append(tuple(map(function, element))) + approx = function(data[0]) + result_lst: list[ + Union[ + tuple[torch.Tensor, torch.Tensor, torch.Tensor], + dict[str, torch.Tensor], + ] + ] = [] + for element in data[1:]: + if isinstance(element, tuple): + result_lst.append( + ( + function(element[0]), + function(element[1]), + function(element[2]), + ) + ) elif isinstance(element, dict): new_dict = { key: function(value) for key, value in element.items() } result_lst.append(new_dict) - if return_tuple: - return tuple(result_lst) - return result_lst + else: + raise AssertionError(f"Unexpected input type {type(element)}") + + return_val = approx, *result_lst + return_val = cast(Union[WaveletTransformReturn2d, WaveletTransformReturn3d], return_val) + return return_val From cb1b3db0ce8cd851479e8dabe23134f77b95bad2 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 3 Jun 2024 23:35:11 +0200 Subject: [PATCH 10/40] Improve wavelet args for separable transforms --- src/ptwt/separable_conv_transform.py | 34 ++++++++++++++++------------ 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index a23cb8d3..f5efa6b3 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -15,6 +15,7 @@ import torch from ._util import ( + Wavelet, WaveletTransformReturn2d, WaveletTransformReturn3d, _as_wavelet, @@ -35,7 +36,7 @@ def _separable_conv_dwtn_( rec_dict: dict[str, torch.Tensor], input_arg: torch.Tensor, - wavelet: Union[str, pywt.Wavelet], + wavelet: Union[Wavelet, str], *, mode: BoundaryMode = "reflect", key: str = "", @@ -46,7 +47,8 @@ def _separable_conv_dwtn_( Args: input_arg (torch.Tensor): Tensor of shape [batch, data_1, ... data_n]. - wavelet (Union[str, pywt.Wavelet]): The Wavelet to work with. + wavelet (Wavelet or str): A pywt wavelet compatible object or + the name of a pywt wavelet. mode : The padding mode. The following methods are supported:: "reflect", "zero", "constant", "periodic". @@ -69,14 +71,15 @@ def _separable_conv_dwtn_( def _separable_conv_idwtn( - in_dict: dict[str, torch.Tensor], wavelet: Union[str, pywt.Wavelet] + in_dict: dict[str, torch.Tensor], wavelet: Union[Wavelet, str] ) -> torch.Tensor: """Separable single level inverse fast wavelet transform. Args: in_dict (dict[str, torch.Tensor]): The dictionary produced by _separable_conv_dwtn_ . - wavelet (Union[str, pywt.Wavelet]): The wavelet used by + wavelet (Wavelet or str): A pywt wavelet compatible object or + the name of a pywt wavelet, as used by _separable_conv_dwtn_ . Returns: @@ -109,7 +112,7 @@ def _separable_conv_idwtn( def _separable_conv_wavedecn( input: torch.Tensor, - wavelet: pywt.Wavelet, + wavelet: Union[Wavelet, str], *, mode: BoundaryMode = "reflect", level: Optional[int] = None, @@ -118,7 +121,8 @@ def _separable_conv_wavedecn( Args: input (torch.Tensor): A tensor i.e. of shape [batch,axis_1, ... axis_n]. - wavelet (Wavelet): A pywt wavelet compatible object. + wavelet (Wavelet or str): A pywt wavelet compatible object or + the name of a pywt wavelet. mode : The desired padding mode. level (int): The desired decomposition level. 
@@ -147,15 +151,15 @@ def _separable_conv_wavedecn( def _separable_conv_waverecn( coeffs: WaveletTransformReturn3d, - wavelet: pywt.Wavelet, + wavelet: Union[Wavelet, str], ) -> torch.Tensor: """Separable n-dimensional wavelet synthesis transform. Args: coeffs (WaveletTransformReturn3d): The output as produced by `_separable_conv_wavedecn`. - wavelet (pywt.Wavelet): - The wavelet used by `_separable_conv_wavedecn`. + wavelet (Wavelet or str): A pywt wavelet compatible object or + the name of a pywt wavelet, as used by `_separable_conv_wavedecn`. Returns: torch.Tensor: The reconstruction of the original signal. @@ -178,7 +182,7 @@ def _separable_conv_waverecn( def fswavedec2( data: torch.Tensor, - wavelet: Union[str, pywt.Wavelet], + wavelet: Union[Wavelet, str], *, mode: BoundaryMode = "reflect", level: Optional[int] = None, @@ -248,7 +252,7 @@ def fswavedec2( def fswavedec3( data: torch.Tensor, - wavelet: Union[str, pywt.Wavelet], + wavelet: Union[Wavelet, str], *, mode: BoundaryMode = "reflect", level: Optional[int] = None, @@ -319,7 +323,7 @@ def fswavedec3( def fswaverec2( coeffs: WaveletTransformReturn3d, - wavelet: Union[str, pywt.Wavelet], + wavelet: Union[Wavelet, str], axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: """Compute a fully separable 2D-padded synthesis wavelet transform. @@ -330,7 +334,7 @@ def fswaverec2( Args: coeffs (WaveletTransformReturn3d): The wavelet coefficients as computed by `fswavedec2`. - wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the + wavelet (Wavelet or str): The wavelet to use for the synthesis transform. axes (tuple[int, int]): Compute the transform over these axes instead of the last two. Defaults to (-2, -1). @@ -386,7 +390,7 @@ def fswaverec2( def fswaverec3( coeffs: WaveletTransformReturn3d, - wavelet: Union[str, pywt.Wavelet], + wavelet: Union[Wavelet, str], axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Compute a fully separable 3D-padded synthesis wavelet transform. @@ -394,7 +398,7 @@ def fswaverec3( Args: coeffs (WaveletTransformReturn3d): The wavelet coefficients as computed by `fswavedec3`. - wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the + wavelet (Wavelet or str): The wavelet to use for the synthesis transform. axes (tuple[int, int, int]): Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1). From 310b1af58f8e8542206ba3f963b42a845e6ba321 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 4 Jun 2024 00:02:23 +0200 Subject: [PATCH 11/40] Rename wavelet coefficient types. 
Renamed * `WaveletTransformReturn3d` to `WaveletCoeffDetailDict`, * `WaveletTransformReturn2d` to `WaveletCoeffDetailTuple2d` and added the type aliases * `WaveletDetailDict` for `dict[str, Tensor]` * `WaveletDetailTuple2d` for `tuple[Tensor, Tensor, Tensor]` --- src/ptwt/_util.py | 21 ++++++++++++--------- src/ptwt/conv_transform_2.py | 10 +++++----- src/ptwt/conv_transform_3.py | 12 ++++++------ src/ptwt/matmul_transform_2.py | 8 ++++---- src/ptwt/matmul_transform_3.py | 8 ++++---- src/ptwt/packets.py | 18 +++++++++--------- src/ptwt/separable_conv_transform.py | 16 ++++++++-------- 7 files changed, 48 insertions(+), 45 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index ff5421bf..d8d597cf 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -30,8 +30,11 @@ def __len__(self) -> int: """Return the number of filter coefficients.""" return len(self.dec_lo) -WaveletTransformReturn2d = tuple[torch.Tensor, Unpack[tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]]] -WaveletTransformReturn3d = tuple[torch.Tensor, Unpack[tuple[dict[str, torch.Tensor], ...]]] +WaveletDetailTuple2d = tuple[torch.Tensor, torch.Tensor, torch.Tensor] +WaveletDetailDict = dict[str, torch.Tensor] + +WaveletCoeffDetailTuple2d = tuple[torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]]] +WaveletCoeffDetailDict = tuple[torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]]] def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet: @@ -171,24 +174,24 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: @overload def _map_result( - data: WaveletTransformReturn2d, + data: WaveletCoeffDetailTuple2d, function: Callable[[torch.Tensor], torch.Tensor], -) -> WaveletTransformReturn2d: +) -> WaveletCoeffDetailTuple2d: ... @overload def _map_result( - data: WaveletTransformReturn3d, + data: WaveletCoeffDetailDict, function: Callable[[torch.Tensor], torch.Tensor], -) -> WaveletTransformReturn3d: +) -> WaveletCoeffDetailDict: ... def _map_result( - data: Union[WaveletTransformReturn2d, WaveletTransformReturn3d], + data: Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], function: Callable[[torch.Tensor], torch.Tensor] -) -> Union[WaveletTransformReturn2d, WaveletTransformReturn3d]: +) -> Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict]: return_tuple = isinstance(data, tuple) approx = function(data[0]) result_lst: list[ @@ -216,5 +219,5 @@ def _map_result( raise AssertionError(f"Unexpected input type {type(element)}") return_val = approx, *result_lst - return_val = cast(Union[WaveletTransformReturn2d, WaveletTransformReturn3d], return_val) + return_val = cast(Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], return_val) return return_val diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index cf87366b..229c910f 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -13,7 +13,7 @@ from ._util import ( Wavelet, - WaveletTransformReturn2d, + WaveletCoeffDetailTuple2d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -97,7 +97,7 @@ def _fwt_pad2( return data_pad -def _waverec2d_fold_channels_2d_list(coeffs: WaveletTransformReturn2d) -> tuple[WaveletTransformReturn2d, list[int]]: +def _waverec2d_fold_channels_2d_list(coeffs: WaveletCoeffDetailTuple2d) -> tuple[WaveletCoeffDetailTuple2d, list[int]]: # fold the input coefficients for processing conv2d_transpose. 
ds = list(_check_if_tensor(coeffs[0]).shape) return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds @@ -128,7 +128,7 @@ def wavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int] = (-2, -1), -) -> WaveletTransformReturn2d: +) -> WaveletCoeffDetailTuple2d: r"""Run a two-dimensional wavelet transformation. This function relies on two-dimensional convolutions. @@ -222,7 +222,7 @@ def wavedec2( result_lst.reverse() res_ll = res_ll.squeeze(1) - result: WaveletTransformReturn2d = res_ll, *result_lst + result: WaveletCoeffDetailTuple2d = res_ll, *result_lst if ds: _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) @@ -236,7 +236,7 @@ def wavedec2( def waverec2( - coeffs: WaveletTransformReturn2d, + coeffs: WaveletCoeffDetailTuple2d, wavelet: Union[Wavelet, str], axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 77ae26b4..a9debf59 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -12,7 +12,7 @@ from ._util import ( Wavelet, - WaveletTransformReturn3d, + WaveletCoeffDetailDict, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -110,7 +110,7 @@ def wavedec3( mode: BoundaryMode = "zero", level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), -) -> WaveletTransformReturn3d: +) -> WaveletCoeffDetailDict: """Compute a three-dimensional wavelet transform. Args: @@ -198,7 +198,7 @@ def wavedec3( } ) result_lst.reverse() - result: WaveletTransformReturn3d = res_lll, *result_lst + result: WaveletCoeffDetailDict = res_lll, *result_lst if ds: _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) @@ -212,9 +212,9 @@ def wavedec3( def _waverec3d_fold_channels_3d_list( - coeffs: WaveletTransformReturn3d, + coeffs: WaveletCoeffDetailDict, ) -> tuple[ - WaveletTransformReturn3d, + WaveletCoeffDetailDict, list[int], ]: # fold the input coefficients for processing conv2d_transpose. @@ -232,7 +232,7 @@ def _waverec3d_fold_channels_3d_list( def waverec3( - coeffs: WaveletTransformReturn3d, + coeffs: WaveletCoeffDetailDict, wavelet: Union[Wavelet, str], axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 7210798c..f1ff9157 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -13,7 +13,7 @@ from ._util import ( Wavelet, - WaveletTransformReturn2d, + WaveletCoeffDetailTuple2d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -418,7 +418,7 @@ def _construct_analysis_matrices( def __call__( self, input_signal: torch.Tensor - ) -> WaveletTransformReturn2d: + ) -> WaveletCoeffDetailTuple2d: """Compute the fwt for the given input signal. The fwt matrix is set up during the first call @@ -544,7 +544,7 @@ def __call__( ll = ll.T.reshape(batch_size, size[1] // 2, size[0] // 2).transpose(2, 1) split_list.reverse() - result: WaveletTransformReturn2d = ll, *split_list + result: WaveletCoeffDetailTuple2d = ll, *split_list if ds: _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) @@ -726,7 +726,7 @@ def _construct_synthesis_matrices( def __call__( self, - coefficients: WaveletTransformReturn2d, + coefficients: WaveletCoeffDetailTuple2d, ) -> torch.Tensor: """Compute the inverse matrix 2d fast wavelet transform. 
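The annotation change above is typing-only; the call pattern of the matrix
transforms is unchanged. A minimal round-trip sketch (Haar wavelet assumed,
mirroring the MatrixWavedec2 docstring example) whose result now carries the
WaveletCoeffDetailTuple2d type:

    import pywt
    import torch
    import ptwt

    data = torch.randn(4, 32, 32)  # [batch, height, width], divisible by 2**level
    matrix_fwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2)
    coeffs = matrix_fwt(data)  # pywt-style tuple: (ll, (lh, hl, hh), ...)
    matrix_ifwt = ptwt.MatrixWaverec2(pywt.Wavelet("haar"))
    rec = matrix_ifwt(coeffs)  # back to shape [4, 32, 32]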
diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 515f341f..a71fa524 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -9,7 +9,7 @@ from ._util import ( Wavelet, - WaveletTransformReturn3d, + WaveletCoeffDetailDict, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -158,7 +158,7 @@ def _construct_analysis_matrices( def __call__( self, input_signal: torch.Tensor - ) -> WaveletTransformReturn3d: + ) -> WaveletCoeffDetailDict: """Compute a separable 3d-boundary wavelet transform. Args: @@ -261,7 +261,7 @@ def _split_rec( split_list.append(coeff_dict) split_list.reverse() - result: WaveletTransformReturn3d = lll, *split_list + result: WaveletCoeffDetailDict = lll, *split_list if ds: _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) @@ -389,7 +389,7 @@ def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Ten return self._cat_coeff_recursive(done_dict) def __call__( - self, coefficients: WaveletTransformReturn3d + self, coefficients: WaveletCoeffDetailDict ) -> torch.Tensor: """Reconstruct a batched 3d-signal from its coefficients. diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 54da3358..e1b92311 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -10,7 +10,7 @@ import pywt import torch -from ._util import Wavelet, WaveletTransformReturn2d, WaveletTransformReturn3d, _as_wavelet +from ._util import Wavelet, WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict, _as_wavelet from .constants import ExtendedBoundaryMode, OrthogonalizeMethod from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 @@ -384,7 +384,7 @@ def get_natural_order(self, level: int) -> list[str]: return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)] def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ - [torch.Tensor], WaveletTransformReturn2d, + [torch.Tensor], WaveletCoeffDetailTuple2d, ]: if self.mode == "boundary": shape = tuple(shape) @@ -413,7 +413,7 @@ def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes ) - def _get_waverec(self, shape: tuple[int, ...]) -> Callable[[WaveletTransformReturn2d], torch.Tensor]: + def _get_waverec(self, shape: tuple[int, ...]) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: if self.mode == "boundary": shape = tuple(shape) if shape not in self.matrix_waverec2_dict.keys(): @@ -433,11 +433,11 @@ def _get_waverec(self, shape: tuple[int, ...]) -> Callable[[WaveletTransformRetu def _transform_fsdict_to_tuple_func( self, - fs_dict_func: Callable[[torch.Tensor], WaveletTransformReturn3d], - ) -> Callable[[torch.Tensor], WaveletTransformReturn2d]: + fs_dict_func: Callable[[torch.Tensor], WaveletCoeffDetailDict], + ) -> Callable[[torch.Tensor], WaveletCoeffDetailTuple2d]: def _tuple_func( data: torch.Tensor, - ) -> WaveletTransformReturn2d: + ) -> WaveletCoeffDetailTuple2d: fs_dict_data = fs_dict_func(data) # assert for type checking assert len(fs_dict_data) == 2 @@ -448,10 +448,10 @@ def _tuple_func( def _transform_tuple_to_fsdict_func( self, - fsdict_func: Callable[[WaveletTransformReturn3d], torch.Tensor], - ) -> Callable[[WaveletTransformReturn2d], torch.Tensor]: + fsdict_func: Callable[[WaveletCoeffDetailDict], torch.Tensor], + ) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: def _fsdict_func( - coeffs: WaveletTransformReturn2d + coeffs: WaveletCoeffDetailTuple2d ) -> torch.Tensor: # assert for type checking assert 
len(coeffs) == 2 diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index f5efa6b3..63d7e142 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -16,8 +16,8 @@ from ._util import ( Wavelet, - WaveletTransformReturn2d, - WaveletTransformReturn3d, + WaveletCoeffDetailTuple2d, + WaveletCoeffDetailDict, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -116,7 +116,7 @@ def _separable_conv_wavedecn( *, mode: BoundaryMode = "reflect", level: Optional[int] = None, -) -> WaveletTransformReturn3d: +) -> WaveletCoeffDetailDict: """Compute a multilevel separable padded wavelet analysis transform. Args: @@ -150,7 +150,7 @@ def _separable_conv_wavedecn( def _separable_conv_waverecn( - coeffs: WaveletTransformReturn3d, + coeffs: WaveletCoeffDetailDict, wavelet: Union[Wavelet, str], ) -> torch.Tensor: """Separable n-dimensional wavelet synthesis transform. @@ -187,7 +187,7 @@ def fswavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int] = (-2, -1), -) -> WaveletTransformReturn3d: +) -> WaveletCoeffDetailDict: """Compute a fully separable 2D-padded analysis wavelet transform. Args: @@ -257,7 +257,7 @@ def fswavedec3( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), -) -> WaveletTransformReturn3d: +) -> WaveletCoeffDetailDict: """Compute a fully separable 3D-padded analysis wavelet transform. Args: @@ -322,7 +322,7 @@ def fswavedec3( def fswaverec2( - coeffs: WaveletTransformReturn3d, + coeffs: WaveletCoeffDetailDict, wavelet: Union[Wavelet, str], axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: @@ -389,7 +389,7 @@ def fswaverec2( def fswaverec3( - coeffs: WaveletTransformReturn3d, + coeffs: WaveletCoeffDetailDict, wavelet: Union[Wavelet, str], axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: From c2dccd0e200dd42f11b141ceda86e5e3c328e985 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 4 Jun 2024 00:09:24 +0200 Subject: [PATCH 12/40] tighten return type --- src/ptwt/packets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index e1b92311..3376b88e 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -183,7 +183,7 @@ def _get_wavedec( def _get_waverec( self, length: int, - ) -> Callable[[list[torch.Tensor]], torch.Tensor]: + ) -> Callable[[Sequence[torch.Tensor]], torch.Tensor]: if self.mode == "boundary": if length not in self._matrix_waverec_dict.keys(): self._matrix_waverec_dict[length] = MatrixWaverec( From 57c1ab7f147b26532b14b0b65728733ea2afdf63 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 4 Jun 2024 00:16:36 +0200 Subject: [PATCH 13/40] Format --- src/ptwt/_util.py | 30 ++++++++++++++-------------- src/ptwt/conv_transform.py | 1 + src/ptwt/conv_transform_2.py | 4 +++- src/ptwt/conv_transform_3.py | 5 +---- src/ptwt/matmul_transform_2.py | 4 +--- src/ptwt/matmul_transform_3.py | 8 ++------ src/ptwt/packets.py | 18 +++++++++++------ src/ptwt/separable_conv_transform.py | 2 +- 8 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index d8d597cf..951d340a 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -1,13 +1,13 @@ """Utility methods to compute wavelet decompositions from a dataset.""" -from collections.abc import Sequence import typing +from collections.abc import Sequence from typing import Any, Callable, Optional, Protocol, Union, cast, overload -from 
typing_extensions import Unpack import numpy as np import pywt import torch +from typing_extensions import Unpack from .constants import OrthogonalizeMethod @@ -30,10 +30,13 @@ def __len__(self) -> int: """Return the number of filter coefficients.""" return len(self.dec_lo) + WaveletDetailTuple2d = tuple[torch.Tensor, torch.Tensor, torch.Tensor] WaveletDetailDict = dict[str, torch.Tensor] -WaveletCoeffDetailTuple2d = tuple[torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]]] +WaveletCoeffDetailTuple2d = tuple[ + torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]] +] WaveletCoeffDetailDict = tuple[torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]]] @@ -96,7 +99,7 @@ def _pad_symmetric_1d(signal: torch.Tensor, pad_list: tuple[int, int]) -> torch. cat_list.insert(0, signal[:padl].flip(0)) if padr > 0: cat_list.append(signal[-padr::].flip(0)) - return torch.cat(cat_list, axis=0) # type: ignore + return torch.cat(cat_list, dim=0) def _pad_symmetric( @@ -118,7 +121,7 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int """Fold unchanged leading dimensions into a single batch dimension. Args: - data ( torch.Tensor): The input data array. + data (torch.Tensor): The input data array. keep_no (int): The number of dimensions to keep. Returns: @@ -176,21 +179,19 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: def _map_result( data: WaveletCoeffDetailTuple2d, function: Callable[[torch.Tensor], torch.Tensor], -) -> WaveletCoeffDetailTuple2d: - ... +) -> WaveletCoeffDetailTuple2d: ... @overload def _map_result( data: WaveletCoeffDetailDict, function: Callable[[torch.Tensor], torch.Tensor], -) -> WaveletCoeffDetailDict: - ... +) -> WaveletCoeffDetailDict: ... def _map_result( data: Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], - function: Callable[[torch.Tensor], torch.Tensor] + function: Callable[[torch.Tensor], torch.Tensor], ) -> Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict]: return_tuple = isinstance(data, tuple) approx = function(data[0]) @@ -210,14 +211,13 @@ def _map_result( ) ) elif isinstance(element, dict): - new_dict = { - key: function(value) - for key, value in element.items() - } + new_dict = {key: function(value) for key, value in element.items()} result_lst.append(new_dict) else: raise AssertionError(f"Unexpected input type {type(element)}") return_val = approx, *result_lst - return_val = cast(Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], return_val) + return_val = cast( + Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], return_val + ) return return_val diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 66c9798e..77a39776 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -2,6 +2,7 @@ This module treats boundaries with edge-padding. """ + from collections.abc import Sequence from typing import Optional, Union, cast diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 229c910f..0b737129 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -97,7 +97,9 @@ def _fwt_pad2( return data_pad -def _waverec2d_fold_channels_2d_list(coeffs: WaveletCoeffDetailTuple2d) -> tuple[WaveletCoeffDetailTuple2d, list[int]]: +def _waverec2d_fold_channels_2d_list( + coeffs: WaveletCoeffDetailTuple2d, +) -> tuple[WaveletCoeffDetailTuple2d, list[int]]: # fold the input coefficients for processing conv2d_transpose. 
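# Illustrative sketch (not part of the patch): the overloaded _map_result
# helper applies a function to every tensor while keeping the container
# structure intact. It is a private helper, so this only demonstrates the
# signature introduced above.
import torch
from ptwt._util import _map_result

coeffs = (
    torch.ones(1, 8, 8),
    (torch.ones(1, 8, 8), torch.ones(1, 8, 8), torch.ones(1, 8, 8)),
)
halved = _map_result(coeffs, lambda t: 0.5 * t)
# halved keeps the (approx, (lh, hl, hh)) nesting, with every tensor halved.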
ds = list(_check_if_tensor(coeffs[0]).shape) return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index a9debf59..6883b00e 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -222,10 +222,7 @@ def _waverec3d_fold_channels_3d_list( fold_coeffs: list[dict[str, torch.Tensor]] = [] ds = list(_check_if_tensor(coeffs[0]).shape) fold_coeffs = [ - { - key: _fold_axes(value, 3)[0] - for key, value in coeff.items() - } + {key: _fold_axes(value, 3)[0] for key, value in coeff.items()} for coeff in coeffs[1:] ] return (fold_approx_coeff, *fold_coeffs), ds diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index f1ff9157..fc596953 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -416,9 +416,7 @@ def _construct_analysis_matrices( current_width = current_width // 2 self.size_list.append((current_height, current_width)) - def __call__( - self, input_signal: torch.Tensor - ) -> WaveletCoeffDetailTuple2d: + def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: """Compute the fwt for the given input signal. The fwt matrix is set up during the first call diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index a71fa524..18736915 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -156,9 +156,7 @@ def _construct_analysis_matrices( ) self.size_list.append((current_depth, current_height, current_width)) - def __call__( - self, input_signal: torch.Tensor - ) -> WaveletCoeffDetailDict: + def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailDict: """Compute a separable 3d-boundary wavelet transform. Args: @@ -388,9 +386,7 @@ def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Ten return cat_tensor return self._cat_coeff_recursive(done_dict) - def __call__( - self, coefficients: WaveletCoeffDetailDict - ) -> torch.Tensor: + def __call__(self, coefficients: WaveletCoeffDetailDict) -> torch.Tensor: """Reconstruct a batched 3d-signal from its coefficients. Args: diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 3376b88e..868ffe07 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -10,7 +10,12 @@ import pywt import torch -from ._util import Wavelet, WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict, _as_wavelet +from ._util import ( + Wavelet, + WaveletCoeffDetailDict, + WaveletCoeffDetailTuple2d, + _as_wavelet, +) from .constants import ExtendedBoundaryMode, OrthogonalizeMethod from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 @@ -384,7 +389,8 @@ def get_natural_order(self, level: int) -> list[str]: return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)] def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ - [torch.Tensor], WaveletCoeffDetailTuple2d, + [torch.Tensor], + WaveletCoeffDetailTuple2d, ]: if self.mode == "boundary": shape = tuple(shape) @@ -413,7 +419,9 @@ def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes ) - def _get_waverec(self, shape: tuple[int, ...]) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: + def _get_waverec( + self, shape: tuple[int, ...] 
+ ) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: if self.mode == "boundary": shape = tuple(shape) if shape not in self.matrix_waverec2_dict.keys(): @@ -450,9 +458,7 @@ def _transform_tuple_to_fsdict_func( self, fsdict_func: Callable[[WaveletCoeffDetailDict], torch.Tensor], ) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: - def _fsdict_func( - coeffs: WaveletCoeffDetailTuple2d - ) -> torch.Tensor: + def _fsdict_func(coeffs: WaveletCoeffDetailTuple2d) -> torch.Tensor: # assert for type checking assert len(coeffs) == 2 a, (h, v, d) = coeffs diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 63d7e142..bde51f71 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -16,8 +16,8 @@ from ._util import ( Wavelet, - WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict, + WaveletCoeffDetailTuple2d, _as_wavelet, _check_axes_argument, _check_if_tensor, From 995029472050216c88927192cdef99433f9d6442 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 4 Jun 2024 00:45:11 +0200 Subject: [PATCH 14/40] Address linter remarks --- src/ptwt/_stationary_transform.py | 3 ++- src/ptwt/_util.py | 1 - src/ptwt/conv_transform_2.py | 8 ++++---- src/ptwt/conv_transform_3.py | 5 +++-- src/ptwt/matmul_transform.py | 6 +++--- src/ptwt/matmul_transform_2.py | 7 +++---- src/ptwt/matmul_transform_3.py | 4 ++-- src/ptwt/packets.py | 2 +- src/ptwt/separable_conv_transform.py | 14 ++++++-------- 9 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/ptwt/_stationary_transform.py b/src/ptwt/_stationary_transform.py index e188dd52..5577e0b5 100644 --- a/src/ptwt/_stationary_transform.py +++ b/src/ptwt/_stationary_transform.py @@ -115,7 +115,8 @@ def _iswt( """Inverts a 1d stationary wavelet transform. Args: - coeffs (Sequence[torch.Tensor]): The coefficients as computed by the swt function. + coeffs (Sequence[torch.Tensor]): The coefficients as computed + by the swt function. wavelet (Union[pywt.Wavelet, str]): The wavelet used for the forward transform. axis (int, optional): The axis the forward trasform was computed over. Defaults to -1. diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 951d340a..e7c5c032 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -193,7 +193,6 @@ def _map_result( data: Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], function: Callable[[torch.Tensor], torch.Tensor], ) -> Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict]: - return_tuple = isinstance(data, tuple) approx = function(data[0]) result_lst: list[ Union[ diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 0b737129..c9e0f5be 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -4,7 +4,6 @@ torch.nn.functional.conv_transpose2d under the hood. """ -from collections.abc import Sequence from functools import partial from typing import Optional, Union, cast @@ -169,7 +168,7 @@ def wavedec2( last two. Defaults to (-2, -1). Returns: - WaveletTransformReturn2d: A tuple containing the wavelet coefficients. + WaveletCoeffDetailTuple2d: A tuple containing the wavelet coefficients. The coefficients are in pywt order. That is:: [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] . @@ -248,8 +247,9 @@ def waverec2( or forward transform by running transposed convolutions. Args: - coeffs (WaveletTransformReturn2d): The wavelet coefficient tupl produced by wavedec2. - The coefficients must be in pywt order. 
That is:: + coeffs (WaveletCoeffDetailTuple2d): The wavelet coefficient tuple + produced by wavedec2. The coefficients must be in pywt order. + That is:: [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] . diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 6883b00e..48e5ab61 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -127,7 +127,7 @@ def wavedec3( instead of the last three. Defaults to (-3, -2, -1). Returns: - WaveletTransformReturn3d: A tuple with the lll coefficients and + WaveletCoeffDetailDict: A tuple with the lll coefficients and dictionaries with the filter order strings:: ("aad", "ada", "add", "daa", "dad", "dda", "ddd") @@ -236,7 +236,8 @@ def waverec3( """Reconstruct a signal from wavelet coefficients. Args: - coeffs (WaveletTransformReturn3d): The wavelet coefficient tuple produced by wavedec3. + coeffs (WaveletCoeffDetailDict): The wavelet coefficient tuple + produced by wavedec3. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. axes (tuple[int, int, int]): Transform these axes instead of the diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 6bf12f33..ff951d9e 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -8,7 +8,7 @@ """ import sys -from collections import Sequence +from collections.abc import Sequence from typing import Optional, Union import numpy as np @@ -600,8 +600,8 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: """Run the synthesis or inverse matrix fwt. Args: - coefficients (Sequence[torch.Tensor]): The coefficients produced by the forward - transform. + coefficients (Sequence[torch.Tensor]): The coefficients produced + by the forward transform. Returns: torch.Tensor: The input signal reconstruction. diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index fc596953..5d4e4505 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -4,7 +4,6 @@ """ import sys -from collections.abc import Sequence from functools import partial from typing import Optional, Union, cast @@ -430,8 +429,8 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: This transform affects the last two dimensions. Returns: - (WaveletTransformReturn2d): The resulting coefficients per level are stored in - a pywt style tuple. The tuple is ordered as:: + (WaveletCoeffDetailTuple2d): The resulting coefficients per level + are stored in a pywt style tuple. The tuple is ordered as:: (ll, (lh, hl, hh), ...) @@ -729,7 +728,7 @@ def __call__( """Compute the inverse matrix 2d fast wavelet transform. Args: - coefficients (WaveletTransformReturn2d): The coefficient tuple as returned + coefficients (WaveletCoeffDetailTuple2d): The coefficient tuple as returned by the `MatrixWavedec2`-Object. Returns: diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 18736915..dc6673c7 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -167,7 +167,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailDict: ValueError: If the input dimensions don't work. Returns: - WaveletTransformReturn3d: + WaveletCoeffDetailDict: A tuple with the approximation coefficients, and a coefficient dict for each scale. """ @@ -390,7 +390,7 @@ def __call__(self, coefficients: WaveletCoeffDetailDict) -> torch.Tensor: """Reconstruct a batched 3d-signal from its coefficients. 
Args: - coefficients (WaveletTransformReturn3d): + coefficients (WaveletCoeffDetailDict): The output from MatrixWavedec3, consisting of a tuple of the approximation coefficients and a dict with the detail coefficients for each scale. diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 868ffe07..8ec9e031 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from functools import partial from itertools import product -from typing import TYPE_CHECKING, Callable, Optional, Union, cast +from typing import TYPE_CHECKING, Callable, Optional, Union import numpy as np import pywt diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index bde51f71..0a6ba96c 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -11,13 +11,11 @@ from typing import Optional, Union import numpy as np -import pywt import torch from ._util import ( Wavelet, WaveletCoeffDetailDict, - WaveletCoeffDetailTuple2d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -127,7 +125,7 @@ def _separable_conv_wavedecn( level (int): The desired decomposition level. Returns: - WaveletTransformReturn3d: A tuple with the approximation coefficients, + WaveletCoeffDetailDict: A tuple with the approximation coefficients, and a coefficient dict for each scale. """ result: list[dict[str, torch.Tensor]] = [] @@ -156,7 +154,7 @@ def _separable_conv_waverecn( """Separable n-dimensional wavelet synthesis transform. Args: - coeffs (WaveletTransformReturn3d): + coeffs (WaveletCoeffDetailDict): The output as produced by `_separable_conv_wavedecn`. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet, as used by `_separable_conv_wavedecn`. @@ -208,7 +206,7 @@ def fswavedec2( ValueError: If the data is not a batched 2D signal. Returns: - WaveletTransformReturn3d: + WaveletCoeffDetailDict: A tuple with the lll coefficients and dictionaries with the filter order strings:: @@ -277,7 +275,7 @@ def fswavedec3( ValueError: If the input is not a batched 3D signal. Returns: - WaveletTransformReturn3d: + WaveletCoeffDetailDict: A tuple with the lll coefficients and dictionaries with the filter order strings:: @@ -332,7 +330,7 @@ def fswaverec2( the hood. Args: - coeffs (WaveletTransformReturn3d): + coeffs (WaveletCoeffDetailDict): The wavelet coefficients as computed by `fswavedec2`. wavelet (Wavelet or str): The wavelet to use for the synthesis transform. @@ -396,7 +394,7 @@ def fswaverec3( """Compute a fully separable 3D-padded synthesis wavelet transform. Args: - coeffs (WaveletTransformReturn3d): + coeffs (WaveletCoeffDetailDict): The wavelet coefficients as computed by `fswavedec3`. wavelet (Wavelet or str): The wavelet to use for the synthesis transform. From c9af10c362c824d8e3584e8a445040f8fb05b3fe Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 4 Jun 2024 00:46:01 +0200 Subject: [PATCH 15/40] Adopt flake8 rules recommended by black project --- .flake8 | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 4aa3a6f7..9fdc12ec 100644 --- a/.flake8 +++ b/.flake8 @@ -23,6 +23,8 @@ ignore = # asserts are ok in test. S101 C901 +extend-select = B950 +extend-ignore = E501,E701,E704 exclude = .tox, .git, @@ -37,7 +39,7 @@ exclude = .eggs, data. 
src/ptwt/__init__.py -max-line-length = 90 +max-line-length = 80 max-complexity = 20 import-order-style = pycharm application-import-names = From df1884eb45a9d4403aa5ce8fcb9bcc353a86fcc2 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 4 Jun 2024 01:03:13 +0200 Subject: [PATCH 16/40] Remove matplotlib requirement --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 624b60c7..653580d2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,6 @@ install_requires = torch scipy>=1.10 pooch - matplotlib numpy pytest nox From 414492d70a173f08a1496bb82954ae26dbbdf194 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 4 Jun 2024 02:32:25 +0200 Subject: [PATCH 17/40] Minor typing improvement --- src/ptwt/conv_transform.py | 32 +++++++++++++------------------- src/ptwt/conv_transform_2.py | 4 ++-- src/ptwt/conv_transform_3.py | 5 ++--- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 77a39776..b471ce92 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -4,13 +4,14 @@ """ from collections.abc import Sequence -from typing import Optional, Union, cast +from typing import Optional, Union import pywt import torch from ._util import ( Wavelet, + WaveletCoeffDetailTuple2d, _as_wavelet, _fold_axes, _get_len, @@ -159,7 +160,7 @@ def _fwt_pad( # convert pywt to pytorch convention. if mode is None: - mode = cast(BoundaryMode, "reflect") + mode = "reflect" pytorch_mode = _translate_boundary_strings(mode) padr, padl = _get_pad(data.shape[-1], _get_len(wavelet)) @@ -171,33 +172,26 @@ def _fwt_pad( def _flatten_2d_coeff_lst( - coeff_lst_2d: Sequence[ - Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ], + coeff_lst_2d: WaveletCoeffDetailTuple2d, flatten_tensors: bool = True, ) -> list[torch.Tensor]: """Flattens a sequence of tensor tuples into a single list. Args: - coeff_lst_2d (Sequence): A pywt-style coefficient sequence of torch tensors. + coeff_lst_2d (WaveletCoeffDetailTuple2d): A pywt-style + coefficient tuple of torch tensors. flatten_tensors (bool): If true, 2d tensors are flattened. Defaults to True. Returns: list: A single 1-d list with all original elements. 
""" - flat_coeff_lst = [] - for coeff in coeff_lst_2d: - if isinstance(coeff, tuple): - for c in coeff: - if flatten_tensors: - flat_coeff_lst.append(c.flatten()) - else: - flat_coeff_lst.append(c) - else: - if flatten_tensors: - flat_coeff_lst.append(coeff.flatten()) - else: - flat_coeff_lst.append(coeff) + + def _process_tensor(coeff: torch.Tensor) -> torch.Tensor: + return coeff.flatten() if flatten_tensors else coeff + + flat_coeff_lst = [_process_tensor(coeff_lst_2d[0])] + for coeff_tuple in coeff_lst_2d[1:]: + flat_coeff_lst.extend(map(_process_tensor, coeff_tuple)) return flat_coeff_lst diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index c9e0f5be..9acf3457 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -5,7 +5,7 @@ """ from functools import partial -from typing import Optional, Union, cast +from typing import Optional, Union import pywt import torch @@ -82,7 +82,7 @@ def _fwt_pad2( """ if mode is None: - mode = cast(BoundaryMode, "reflect") + mode = "reflect" pytorch_mode = _translate_boundary_strings(mode) wavelet = _as_wavelet(wavelet) padb, padt = _get_pad(data.shape[-2], _get_len(wavelet)) diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 48e5ab61..9787d2ce 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -3,9 +3,8 @@ The functions here are based on torch.nn.functional.conv3d and it's transpose. """ -from collections.abc import Sequence from functools import partial -from typing import Optional, Union, cast +from typing import Optional, Union import pywt import torch @@ -292,7 +291,7 @@ def waverec3( filt_len = rec_lo.shape[-1] rec_filt = _construct_3d_filt(lo=rec_lo, hi=rec_hi) - coeff_dicts = cast(Sequence[dict[str, torch.Tensor]], coeffs[1:]) + coeff_dicts = coeffs[1:] for c_pos, coeff_dict in enumerate(coeff_dicts): if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7: raise ValueError( From f495accd86da35cfea382b1eaed1d51685e15cf0 Mon Sep 17 00:00:00 2001 From: Felix Blanke <45953206+felixblanke@users.noreply.github.com> Date: Tue, 11 Jun 2024 16:40:50 +0200 Subject: [PATCH 18/40] Change error type to ValueError at input validation Co-authored-by: Charles Tapley Hoyt --- src/ptwt/_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index e7c5c032..7fd48a0a 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -213,7 +213,7 @@ def _map_result( new_dict = {key: function(value) for key, value in element.items()} result_lst.append(new_dict) else: - raise AssertionError(f"Unexpected input type {type(element)}") + raise ValueError(f"Unexpected input type {type(element)}") return_val = approx, *result_lst return_val = cast( From acd87690a9694ee1dd4b22a29bec3054ef4e5de5 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 11 Jun 2024 17:54:00 +0200 Subject: [PATCH 19/40] Improve type hints --- src/ptwt/_stationary_transform.py | 19 +++++---- src/ptwt/_util.py | 7 ++-- src/ptwt/continuous_transform.py | 46 ++++++++++------------ src/ptwt/conv_transform.py | 27 ++++++------- src/ptwt/conv_transform_2.py | 11 +++--- src/ptwt/conv_transform_3.py | 21 +++++----- src/ptwt/matmul_transform.py | 28 ++++++-------- src/ptwt/matmul_transform_2.py | 53 +++++++++++-------------- src/ptwt/matmul_transform_3.py | 17 +++++--- src/ptwt/packets.py | 14 +++---- src/ptwt/separable_conv_transform.py | 58 ++++++++++++++-------------- src/ptwt/sparse_math.py | 54 ++++++++++++-------------- 
src/ptwt/wavelets_learnable.py | 29 +++++++------- 13 files changed, 176 insertions(+), 208 deletions(-) diff --git a/src/ptwt/_stationary_transform.py b/src/ptwt/_stationary_transform.py index 5577e0b5..fb4f3798 100644 --- a/src/ptwt/_stationary_transform.py +++ b/src/ptwt/_stationary_transform.py @@ -24,13 +24,12 @@ def _swt( """Compute a multilevel 1d stationary wavelet transform. Args: - data (torch.Tensor): The input data of shape [batch_size, time]. + data (torch.Tensor): The input data of shape ``[batch_size, time]``. wavelet (Union[Wavelet, str]): The wavelet to use. - level (Optional[int], optional): The number of levels to compute + level (int, optional): The number of levels to compute. Returns: - list[torch.Tensor]: Same as wavedec. - Equivalent to pywt.swt with trim_approx=True. + Same as wavedec. Equivalent to pywt.swt with trim_approx=True. Raises: ValueError: Is the axis argument is not an integer. @@ -84,14 +83,14 @@ def _conv_transpose_dedilate( Args: conv_res (torch.Tensor): The dilated coeffcients - of shape [batch, 2, length]. + of shape ``[batch, 2, length]``. rec_filt (torch.Tensor): The reconstruction filter pair - of shape [1, 2, filter_length]. + of shape ``[1, 2, filter_length]``. dilation (int): The dilation factor. length (int): The signal length. Returns: - torch.Tensor: The deconvolution result. + The deconvolution result. """ to_conv_t_list = [ conv_res[..., fl : (fl + dilation * rec_filt.shape[-1]) : dilation] @@ -121,11 +120,11 @@ def _iswt( axis (int, optional): The axis the forward trasform was computed over. Defaults to -1. + Returns: + A reconstruction of the original swt input. + Raises: ValueError: If the axis argument is not an integer. - - Returns: - torch.Tensor: A reconstruction of the original swt input. """ if axis != -1: swap = [] diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 7fd48a0a..ecac29f3 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -48,8 +48,7 @@ def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet: pywt wavelet compatible object or a valid pywt wavelet name string. Returns: - Wavelet: the input wavelet object or the pywt wavelet object described by the - input str. + The input wavelet object or the pywt wavelet object described by the input str. """ if isinstance(wavelet, str): return pywt.Wavelet(wavelet) @@ -125,8 +124,8 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int keep_no (int): The number of dimensions to keep. Returns: - tuple[torch.Tensor, list[int]]: - The folded result array, and the shape of the original input. + A tuple (result_tensor, input_shape) where result_tensor is the + folded result array, and input_shape the shape of the original input. """ dshape = list(data.shape) return ( diff --git a/src/ptwt/continuous_transform.py b/src/ptwt/continuous_transform.py index e2fb8f89..0c32f3d4 100644 --- a/src/ptwt/continuous_transform.py +++ b/src/ptwt/continuous_transform.py @@ -34,7 +34,7 @@ def cwt( https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py Args: - data (torch.Tensor): The input tensor of shape [batch_size, time]. + data (torch.Tensor): The input tensor of shape ``[batch_size, time]``. scales (torch.Tensor or np.array): The wavelet scales to use. One can use ``f = pywt.scale2frequency(wavelet, scale)/sampling_period`` to determine @@ -50,9 +50,9 @@ def cwt( ValueError: If a scale is too small for the input signal. 
Returns: - tuple[torch.Tensor, np.ndarray]: The first tuple-element contains - the transformation matrix of shape [scales, batch, time]. - The second element contains an array with frequency information. + A tuple (out_tensor, frequencies). The first tuple-element contains + the transformation matrix of shape ``[scales, batch, time]``. + The second element contains an array with frequency information. Example: >>> import torch, ptwt @@ -165,27 +165,23 @@ def _integrate_wavelet( https://github.com/PyWavelets/pywt/blob/cef09e7f419aaf4c39b9f778bdc2d54b32fd7337/pywt/_functions.py#L60 - Parameters - ---------- - wavelet: Wavelet instance or str - Wavelet to integrate. If a string, should be the name of a wavelet. - precision : int, optional - Precision that will be used for wavelet function - approximation computed with the wavefun(level=precision) - Wavelet's method (default: 8). - Returns - ------- - [int_psi, x] : - for orthogonal wavelets - [int_psi_d, int_psi_r, x] : - for other wavelets - Examples - -------- - >>> from pywt import Wavelet, _integrate_wavelet - >>> wavelet1 = Wavelet('db2') - >>> [int_psi, x] = _integrate_wavelet(wavelet1, precision=5) - >>> wavelet2 = Wavelet('bior1.3') - >>> [int_psi_d, int_psi_r, x] = _integrate_wavelet(wavelet2, precision=5) + Args: + wavelet (Wavelet instance or str): Wavelet to integrate. + If a string, should be the name of a wavelet. + precision (int): Precision that will be used for wavelet function + approximation computed with the wavefun(level=precision) + Wavelet's method. Defaults to 8. + + Returns: + A tuple (int_psi, x) for orthogonal wavelets; + for other wavelets, a tuple (int_psi_d, int_psi_r, x) is returned instead. + + Example: + >>> from pywt import Wavelet, _integrate_wavelet + >>> wavelet1 = Wavelet('db2') + >>> [int_psi, x] = _integrate_wavelet(wavelet1, precision=5) + >>> wavelet2 = Wavelet('bior1.3') + >>> [int_psi_d, int_psi_r, x] = _integrate_wavelet(wavelet2, precision=5) """ def _integrate( diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index b471ce92..9d06dee9 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -54,9 +54,8 @@ def _get_filter_tensors( computation. Default: torch.float32. Returns: - tuple: Tuple containing the four filter tensors - dec_lo, dec_hi, rec_lo, rec_hi - + A tuple (dec_lo, dec_hi, rec_lo, rec_hi) containing + the four filter tensors """ wavelet = _as_wavelet(wavelet) device = torch.device(device) @@ -80,10 +79,8 @@ def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]: filt_len (int): The size of the used filter. Returns: - Tuple: The first entry specifies how many numbers - to attach on the right. The second - entry covers the left side. - + A tuple (padr, padl). The first entry specifies how many numbers + to attach on the right. The second entry covers the left side. """ # pad to ensure we see all filter positions and # for pywt compatability. @@ -117,7 +114,6 @@ def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str: Raises: ValueError: If the padding mode is not supported. - """ if pywt_mode == "constant": return "replicate" @@ -153,8 +149,7 @@ def _fwt_pad( Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`. Returns: - torch.Tensor: A PyTorch tensor with the padded input data - + A PyTorch tensor with the padded input data """ wavelet = _as_wavelet(wavelet) @@ -183,7 +178,7 @@ def _flatten_2d_coeff_lst( flatten_tensors (bool): If true, 2d tensors are flattened. Defaults to True. 
Returns: - list: A single 1-d list with all original elements. + A single 1-d list with all original elements. """ def _process_tensor(coeff: torch.Tensor) -> torch.Tensor: @@ -220,9 +215,9 @@ def _preprocess_tensor_dec1d( data (torch.Tensor): An input tensor of any shape. Returns: - tuple[torch.Tensor, Union[list[int], None]]: - A data tensor of shape [new_batch, 1, to_process] - and the original shape, if the shape has changed. + A tuple (data, ds) where data is a data tensor of shape + [new_batch, 1, to_process]. `ds` contains the original shape + if the shape has changed. Otherwise, ds is None. """ ds = None if len(data.shape) == 1: @@ -293,7 +288,7 @@ def wavedec( Returns: - list: A list:: + A list:: [cA_s, cD_s, cD_s-1, …, cD2, cD1] @@ -366,7 +361,7 @@ def waverec( axis (int): Transform this axis instead of the last one. Defaults to -1. Returns: - torch.Tensor: The reconstructed signal. + The reconstructed signal tensor. Raises: ValueError: If the dtype of the coeffs tensor is unsupported or if the diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 9acf3457..a63f19bd 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -43,7 +43,7 @@ def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor: hi (torch.Tensor): High-pass input filter Returns: - torch.Tensor: Stacked 2d-filters of dimension + Stacked 2d-filters of dimension [filt_no, 1, height, width]. @@ -168,8 +168,8 @@ def wavedec2( last two. Defaults to (-2, -1). Returns: - WaveletCoeffDetailTuple2d: A tuple containing the wavelet coefficients. - The coefficients are in pywt order. That is:: + A tuple containing the wavelet coefficients. The coefficients are in pywt order. + That is:: [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] . @@ -261,9 +261,8 @@ def waverec2( last two. Defaults to (-2, -1). Returns: - torch.Tensor: - The reconstructed signal of shape ``[batch, height, width]`` or - ``[batch, channel, height, width]`` depending on the input to `wavedec2`. + The reconstructed signal tensor of shape ``[batch, height, width]`` or + ``[batch, channel, height, width]`` depending on the input to `wavedec2`. Raises: ValueError: If coeffs is not in a shape as returned from wavedec2 or diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 9787d2ce..69144d81 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -42,12 +42,11 @@ def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor: hi (torch.Tensor): High-pass input filter Returns: - torch.Tensor: Stacked 3d filters of dimension:: + Stacked 3d filters of dimension:: [8, 1, length, height, width]. The four filters are ordered ll, lh, hl, hh. - """ dim_size = lo.shape[-1] size = [dim_size] * 3 @@ -81,7 +80,6 @@ def _fwt_pad3( Returns: The padded output tensor. - """ pytorch_mode = _translate_boundary_strings(mode) @@ -126,13 +124,14 @@ def wavedec3( instead of the last three. Defaults to (-3, -2, -1). Returns: - WaveletCoeffDetailDict: A tuple with the lll coefficients and - dictionaries with the filter order strings:: + A tuple with the lll coefficients and for each scale a dictionary + containing the detail coefficients. The dictionaries use + the filter order strings:: - ("aad", "ada", "add", "daa", "dad", "dda", "ddd") + ("aad", "ada", "add", "daa", "dad", "dda", "ddd") - as keys. With 'a' for the low pass or approximation filter and - 'd' for the high-pass or detail filter. + as keys. 
'a' denotes the low pass or approximation filter and + 'd' the high-pass or detail filter. Raises: ValueError: If the input has fewer than three dimensions or @@ -143,7 +142,6 @@ def wavedec3( >>> import ptwt, torch >>> data = torch.randn(5, 16, 16, 16) >>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect") - """ if tuple(axes) != (-3, -2, -1): if len(axes) != 3: @@ -243,8 +241,8 @@ def waverec3( last three. Defaults to (-3, -2, -1). Returns: - torch.Tensor: The reconstructed four-dimensional signal of shape - [batch, depth, height, width]. + The reconstructed four-dimensional signal tensor of shape + ``[batch, depth, height, width]``. Raises: ValueError: If coeffs is not in a shape as returned from wavedec3 or @@ -256,7 +254,6 @@ def waverec3( >>> data = torch.randn(5, 16, 16, 16) >>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect") >>> reconstruction = ptwt.waverec3(transformed, "haar") - """ if tuple(axes) != (-3, -2, -1): if len(axes) != 3: diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index ff951d9e..d0f8f575 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -57,7 +57,7 @@ def _construct_a( or torch.float64. Defaults to torch.float64. Returns: - torch.Tensor: The sparse raw analysis matrix. + The sparse raw analysis matrix. """ wavelet = _as_wavelet(wavelet) dec_lo, dec_hi, _, _ = _get_filter_tensors( @@ -94,7 +94,7 @@ def _construct_s( or torch.float64. Defaults to torch.float64. Returns: - torch.Tensor: The raw sparse synthesis matrix. + The raw sparse synthesis matrix. """ wavelet = _as_wavelet(wavelet) _, _, rec_lo, rec_hi = _get_filter_tensors( @@ -120,7 +120,7 @@ def _get_to_orthogonalize(matrix: torch.Tensor, filt_len: int) -> torch.Tensor: filt_len (int): The number of entries we would expect per row. Returns: - torch.Tensor: The row indices with too few entries. + The row indices with too few entries. """ unique, count = torch.unique_consecutive( matrix.coalesce().indices()[0, :], return_counts=True @@ -142,7 +142,7 @@ def orthogonalize( Defaults to qr. Returns: - torch.Tensor: Orthogonal sparse transformation matrix. + Orthogonal sparse transformation matrix. Raises: ValueError: If an invalid orthogonalization method is given @@ -232,21 +232,18 @@ def __init__( @property def sparse_fwt_operator(self) -> torch.Tensor: - """Return the sparse transformation operator. + """The sparse transformation operator. If the input signal at all levels is divisible by two, the whole operation is padding-free and can be expressed as a single matrix multiply. - The operation torch.sparse.mm(sparse_fwt_operator, data.T) + The operation ``torch.sparse.mm(sparse_fwt_operator, data.T)`` computes a batched fwt. This property exists to make the operator matrix transparent. Calling the object will handle odd-length inputs properly. - Returns: - torch.Tensor: The sparse operator matrix. - Raises: NotImplementedError: if padding had to be used in the creation of the transformation matrices. @@ -330,7 +327,7 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: another axis. Returns: - list[torch.Tensor]: A list with the coefficients for each scale. + A list with the coefficient tensor for each scale. Raises: ValueError: If the decomposition level is not a positive integer @@ -416,7 +413,7 @@ def construct_boundary_a( dtype: Choose float32 or float64. Returns: - torch.Tensor: The sparse analysis matrix. + The sparse analysis matrix. 
""" wavelet = _as_wavelet(wavelet) a_full = _construct_a(wavelet, length, dtype=dtype, device=device) @@ -445,7 +442,7 @@ def construct_boundary_s( Defaults to torch.float64. Returns: - torch.Tensor: The sparse synthesis matrix. + The sparse synthesis matrix. """ wavelet = _as_wavelet(wavelet) s_full = _construct_s(wavelet, length, dtype=dtype, device=device) @@ -514,7 +511,7 @@ def __init__( @property def sparse_ifwt_operator(self) -> torch.Tensor: - """Return the sparse transformation operator. + """The sparse transformation operator. If the input signal at all levels is divisible by two, the whole operation is padding-free and can be expressed @@ -527,9 +524,6 @@ def sparse_ifwt_operator(self) -> torch.Tensor: This functionality is mainly here to make the operator-matrix transparent. Calling the object handles padding for odd inputs. - Returns: - torch.Tensor: The sparse operator matrix. - Raises: NotImplementedError: if padding had to be used in the creation of the transformation matrices. @@ -604,7 +598,7 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: by the forward transform. Returns: - torch.Tensor: The input signal reconstruction. + The input signal reconstruction. Raises: ValueError: If the decomposition level is not a positive integer or if the diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 5d4e4505..55e2754d 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -66,14 +66,12 @@ def _construct_a_2( Defaults to 'sameshift'. Returns: - torch.Tensor: A sparse fwt analysis matrix. - The matrices are ordered a,h,v,d or - ll, lh, hl, hh. + A sparse fwt analysis matrix. + The matrices are ordered a, h, v, d or ll, lh, hl, hh. Note: The constructed matrix is NOT necessarily orthogonal. In most cases, construct_boundary_a2d should be used instead. - """ dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=False, device=device, dtype=dtype @@ -118,7 +116,7 @@ def _construct_s_2( Defaults to 'sameshift'. Returns: - [torch.Tensor]: The generated fast wavelet synthesis matrix. + The generated fast wavelet synthesis matrix. """ wavelet = _as_wavelet(wavelet) _, _, rec_lo, rec_hi = _get_filter_tensors( @@ -167,8 +165,7 @@ def construct_boundary_a2( Defaults to torch.float64. Returns: - torch.Tensor: A sparse fwt matrix, with orthogonalized boundary - wavelets. + A sparse fwt matrix, with orthogonalized boundary wavelets. """ wavelet = _as_wavelet(wavelet) a = _construct_a_2(wavelet, height, width, device, dtype=dtype, mode="sameshift") @@ -200,8 +197,7 @@ def construct_boundary_s2( Defaults to torch.float64. Returns: - torch.Tensor: The synthesis matrix, used to compute the - inverse fast wavelet transform. + The synthesis matrix, used to compute the inverse fast wavelet transform. """ wavelet = _as_wavelet(wavelet) s = _construct_s_2(wavelet, height, width, device, dtype=dtype) @@ -225,14 +221,14 @@ def _matrix_pad_2(height: int, width: int) -> tuple[int, int, tuple[bool, bool]] class MatrixWavedec2(BaseMatrixWaveDec): """Experimental sparse matrix 2d wavelet transform. - For a completely pad-free transform, - input images are expected to be divisible by two. - For multiscale transforms all intermediate - scale dimensions should be divisible - by two, i.e. 128, 128 -> 64, 64 -> 32, 32 would work - well for a level three transform. - In this case multiplication with the `sparse_fwt_operator` - property is equivalent. 
+ For a completely pad-free transform, + input images are expected to be divisible by two. + For multiscale transforms all intermediate + scale dimensions should be divisible + by two, i.e. ``128, 128 -> 64, 64 -> 32, 32`` would work + well for a level three transform. + In this case multiplication with the `sparse_fwt_operator` + property is equivalent. Note: Constructing the sparse fwt-matrix is expensive. @@ -250,7 +246,6 @@ class MatrixWavedec2(BaseMatrixWaveDec): >>> pt_face = torch.tensor(face).permute([2, 0, 1]) >>> matrixfwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2) >>> mat_coeff = matrixfwt(pt_face) - """ def __init__( @@ -314,11 +309,11 @@ def __init__( def sparse_fwt_operator(self) -> torch.Tensor: """Compute the operator matrix for padding-free cases. - This property exists to make the transformation matrix available. - To benefit from code handling odd-length levels call the object. + This property exists to make the transformation matrix available. + To benefit from code handling odd-length levels call the object. Returns: - torch.Tensor: The sparse 2d-fwt operator matrix. + The sparse 2d-fwt operator matrix. Raises: NotImplementedError: if a separable transformation was used or if padding @@ -429,8 +424,8 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: This transform affects the last two dimensions. Returns: - (WaveletCoeffDetailTuple2d): The resulting coefficients per level - are stored in a pywt style tuple. The tuple is ordered as:: + The resulting coefficients per level are stored in a pywt style tuple. + The tuple is ordered as:: (ll, (lh, hl, hh), ...) @@ -628,7 +623,7 @@ def sparse_ifwt_operator(self) -> torch.Tensor: """Compute the ifwt operator matrix for pad-free cases. Returns: - torch.Tensor: The sparse 2d ifwt operator matrix. + The sparse 2d ifwt operator matrix. Raises: NotImplementedError: if a separable transformation was used or if padding @@ -732,12 +727,10 @@ def __call__( by the `MatrixWavedec2`-Object. Returns: - torch.Tensor: The original signal reconstruction. - For example of shape - ``[batch_size, height, width]`` or - ``[batch_size, channels, height, width]`` - depending on the input to the forward transform. - and the value of the `axis` argument. + The original signal reconstruction. For example of shape + ``[batch_size, height, width]`` or ``[batch_size, channels, height, width]`` + depending on the input to the forward transform and the value + of the `axis` argument. Raises: ValueError: If the decomposition level is not a positive integer or if the diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index dc6673c7..0892955a 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -161,15 +161,20 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailDict: Args: input_signal (torch.Tensor): An input signal. For example - of shape [batch_size, depth, height, width]. + of shape ``[batch_size, depth, height, width]``. + + Returns: + A tuple with the lll coefficients and for each scale a dictionary + containing the detail coefficients. The dictionaries use + the filter order strings:: + + ("aad", "ada", "add", "daa", "dad", "dda", "ddd") + + as keys. 'a' denotes the low pass or approximation filter and + 'd' the high-pass or detail filter. Raises: ValueError: If the input dimensions don't work. - - Returns: - WaveletCoeffDetailDict: - A tuple with the approximation coefficients, - and a coefficient dict for each scale. 
""" if self.axes != (-3, -2, -1): input_signal = _swap_axes(input_signal, list(self.axes)) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 8ec9e031..6a0b9cad 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -40,7 +40,7 @@ def _wpfreq(fs: float, level: int) -> list[float]: level (int): The decomposition level. Returns: - list[float]: The frequency bins of the packets in frequency order. + The frequency bins of the packets in frequency order. """ n = np.array(range(int(np.power(2.0, level)))) freqs = (fs / 2.0) * (n / (np.power(2.0, level))) @@ -205,7 +205,7 @@ def get_level(self, level: int) -> list[str]: level (int): The depth of the tree. Returns: - list: A list with the paths to each node. + A list with the paths to each node. """ return self._get_graycode_order(level) @@ -238,7 +238,7 @@ def __getitem__(self, key: str) -> torch.Tensor: of the following chars: 'a', 'd'. Returns: - torch.Tensor: The accessed wavelet packet coefficients. + The accessed wavelet packet coefficients. Raises: ValueError: If the wavelet packet tree is not initialized. @@ -324,7 +324,7 @@ def transform( Args: data (torch.tensor): The input data tensor - of shape [batch_size, height, width] + of shape ``[batch_size, height, width]``. maxlevel (int, optional): The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. @@ -384,7 +384,7 @@ def get_natural_order(self, level: int) -> list[str]: level (int): The decomposition level. Returns: - list: A list with the filter order strings. + A list with the filter order strings. """ return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)] @@ -491,7 +491,7 @@ def __getitem__(self, key: str) -> torch.Tensor: of the following chars: 'a', 'h', 'v', 'd'. Returns: - torch.Tensor: The accessed wavelet packet coefficients. + The accessed wavelet packet coefficients. Raises: ValueError: If the wavelet packet tree is not initialized. @@ -520,7 +520,7 @@ def get_freq_order(level: int) -> list[list[tuple[str, ...]]]: level (int): The number of decomposition scales. Returns: - list: A list with the tree nodes in frequency order. + A list with the tree nodes in frequency order. Note: Adapted from: diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 0a6ba96c..8da099b1 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -44,7 +44,7 @@ def _separable_conv_dwtn_( All but the first axes are transformed. Args: - input_arg (torch.Tensor): Tensor of shape [batch, data_1, ... data_n]. + input_arg (torch.Tensor): Tensor of shape ``[batch, data_1, ... data_n]``. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. mode : The padding mode. The following methods are supported:: @@ -81,7 +81,7 @@ def _separable_conv_idwtn( _separable_conv_dwtn_ . Returns: - torch.Tensor: A reconstruction of the original signal. + A reconstruction of the original signal. """ done_dict = {} a_initial_keys = list(filter(lambda x: x[0] == "a", in_dict.keys())) @@ -118,15 +118,18 @@ def _separable_conv_wavedecn( """Compute a multilevel separable padded wavelet analysis transform. Args: - input (torch.Tensor): A tensor i.e. of shape [batch,axis_1, ... axis_n]. + input (torch.Tensor): A tensor i.e. of shape ``[batch,axis_1, ... axis_n]``. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. mode : The desired padding mode. 
level (int): The desired decomposition level. Returns: - WaveletCoeffDetailDict: A tuple with the approximation coefficients, - and a coefficient dict for each scale. + A tuple with the approximation coefficients and + for each scale a dictionary containing the detail coefficients. + The dictionaries use a string of length n as keys with + 'a' denoting the low pass or approximation filter and + 'd' the high-pass or detail filter. """ result: list[dict[str, torch.Tensor]] = [] approx = input @@ -160,7 +163,7 @@ def _separable_conv_waverecn( the name of a pywt wavelet, as used by `_separable_conv_wavedecn`. Returns: - torch.Tensor: The reconstruction of the original signal. + The reconstruction of the original signal. Raises: ValueError: If the coeffs is not structured as expected. @@ -202,26 +205,24 @@ def fswavedec2( axes ([int, int]): The axes we want to transform, defaults to (-2, -1). - Raises: - ValueError: If the data is not a batched 2D signal. - Returns: - WaveletCoeffDetailDict: - A tuple with the lll coefficients and dictionaries - with the filter order strings:: + A tuple with the ll coefficients and for each scale a dictionary + containing the detail coefficients. The dictionaries use + the filter order strings:: ("ad", "da", "dd") - as keys. With a for the low pass or approximation filter and - d for the high-pass or detail filter. + as keys. 'a' denotes the low pass or approximation filter and + 'd' the high-pass or detail filter. + Raises: + ValueError: If the data is not a batched 2D signal. Example: >>> import torch >>> import ptwt >>> data = torch.randn(5, 10, 10) >>> coeff = ptwt.fswavedec2(data, "haar", level=2) - """ if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") @@ -259,7 +260,7 @@ def fswavedec3( """Compute a fully separable 3D-padded analysis wavelet transform. Args: - data (torch.Tensor): An input signal of shape [batch, depth, height, width]. + data (torch.Tensor): An input signal of shape ``[batch, depth, height, width]``. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of ``pywt.wavelist(kind="discrete")`` for a list of possible choices. @@ -271,18 +272,19 @@ def fswavedec3( axes (tuple[int, int, int]): Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1). - Raises: - ValueError: If the input is not a batched 3D signal. - Returns: - WaveletCoeffDetailDict: - A tuple with the lll coefficients and dictionaries - with the filter order strings:: + A tuple with the lll coefficients and for each scale a dictionary + containing the detail coefficients. The dictionaries use + the filter order strings:: ("aad", "ada", "add", "daa", "dad", "dda", "ddd") - as keys. With a for the low pass or approximation filter and - d for the high-pass or detail filter. + as keys. 'a' denotes the low pass or approximation filter and + 'd' the high-pass or detail filter. + + Raises: + ValueError: If the input is not a batched 3D signal. + Example: >>> import torch @@ -338,8 +340,7 @@ def fswaverec2( axes instead of the last two. Defaults to (-2, -1). Returns: - torch.Tensor: A reconstruction of the signal encoded in the - wavelet coefficients. + A reconstruction of the signal encoded in the wavelet coefficients. Raises: ValueError: If the axes argument is not a tuple of two integers. 
@@ -350,7 +351,6 @@ def fswaverec2( >>> data = torch.randn(5, 10, 10) >>> coeff = ptwt.fswavedec2(data, "haar", level=2) >>> rec = ptwt.fswaverec2(coeff, "haar") - """ if tuple(axes) != (-2, -1): if len(axes) != 2: @@ -402,8 +402,7 @@ def fswaverec3( instead of the last three. Defaults to (-3, -2, -1). Returns: - torch.Tensor: A reconstruction of the signal encoded in the - wavelet coefficients. + A reconstruction of the signal encoded in the wavelet coefficients. Raises: ValueError: If the axes argument is not a tuple with @@ -415,7 +414,6 @@ def fswaverec3( >>> data = torch.randn(5, 10, 10, 10) >>> coeff = ptwt.fswavedec3(data, "haar", level=2) >>> rec = ptwt.fswaverec3(coeff, "haar") - """ if tuple(axes) != (-3, -2, -1): if len(axes) != 3: diff --git a/src/ptwt/sparse_math.py b/src/ptwt/sparse_math.py index f7721cb6..c74800e7 100644 --- a/src/ptwt/sparse_math.py +++ b/src/ptwt/sparse_math.py @@ -16,11 +16,11 @@ def _dense_kron( by memory on my machine. Args: - sparse_tensor_a (torch.Tensor): Sparse 2d-Tensor a of shape [m, n]. - sparse_tensor_b (torch.Tensor): Sparse 2d-Tensor b of shape [p, q]. + sparse_tensor_a (torch.Tensor): Sparse 2d-Tensor a of shape ``[m, n]``. + sparse_tensor_b (torch.Tensor): Sparse 2d-Tensor b of shape ``[p, q]``. Returns: - torch.Tensor: The resulting [mp, nq] tensor. + The resulting ``[mp, nq]`` tensor. """ return torch.kron( @@ -39,11 +39,11 @@ def sparse_kron( https://github.com/scipy/scipy/blob/v1.7.1/scipy/sparse/construct.py#L274-L357 Args: - sparse_tensor_a (torch.Tensor): Sparse 2d-Tensor a of shape [m, n]. - sparse_tensor_b (torch.Tensor): Sparse 2d-Tensor b of shape [p, q]. + sparse_tensor_a (torch.Tensor): Sparse 2d-Tensor a of shape ``[m, n]``. + sparse_tensor_b (torch.Tensor): Sparse 2d-Tensor b of shape ``[p, q]``. Returns: - torch.Tensor: The resulting [mp, nq] tensor. + The resulting tensor of shape ``[mp, nq]``. """ assert sparse_tensor_a.device == sparse_tensor_b.device @@ -102,8 +102,7 @@ def cat_sparse_identity_matrix( The length up to which the diagonal should be elongated. Returns: - torch.Tensor: Square [input, eye] matrix - of size [new_length, new_length] + Square ``[input, eye]`` matrix of size ``[new_length, new_length]`` """ # assert square matrix. assert ( @@ -153,7 +152,7 @@ def sparse_diag( cols (int): The number of columns in the final matrix. Returns: - torch.Tensor: A sparse matrix with a shifted diagonal. + A sparse matrix with a shifted diagonal. """ diag_indices = torch.stack( @@ -200,8 +199,7 @@ def sparse_replace_row( row (torch.Tensor): The row to insert into the sparse matrix. Returns: - torch.Tensor: A sparse matrix, with the new row inserted at - row_index. + A sparse matrix, with the new row inserted at row_index. """ matrix = matrix.coalesce() assert ( @@ -246,7 +244,7 @@ def _orth_by_qr( rows_to_orthogonalize (torch.Tensor): The matrix rows, which need work. Returns: - torch.Tensor: The corrected sparse matrix. + The corrected sparse matrix. """ selection_indices = torch.stack( [ @@ -305,7 +303,7 @@ def _orth_by_gram_schmidt( to_orthogonalize (torch.Tensor): The matrix rows, which need work. Returns: - torch.Tensor: The orthogonalized sparse matrix. + The orthogonalized sparse matrix. """ done: list[int] = [] # loop over the rows we want to orthogonalize @@ -347,11 +345,10 @@ def construct_conv_matrix( Defaults to valid. Returns: - torch.Tensor: The sparse convolution tensor. + The sparse convolution tensor. Raises: ValueError: If the padding is not 'full', 'same' or 'valid'. 
- """ filter_length = len(filter) @@ -400,7 +397,7 @@ def construct_conv2d_matrix( a call to scipy.signal.convolve2d and a reshape. Args: - filter (torch.tensor): A filter of shape [height, width] + filter (torch.tensor): A filter of shape ``[height, width]`` to convolve with. input_rows (int): The number of rows in the input matrix. input_columns (int): The number of columns in the input matrix. @@ -410,7 +407,7 @@ def construct_conv2d_matrix( to save memory. Defaults to True. Returns: - torch.Tensor: A sparse convolution matrix. + A sparse convolution matrix. Raises: ValueError: If the padding mode is neither full, same or valid. @@ -476,7 +473,7 @@ def construct_strided_conv_matrix( Defaults to 'valid'. Returns: - torch.Tensor: The strided sparse convolution matrix. + The strided sparse convolution matrix. """ conv_matrix = construct_conv_matrix(filter, input_rows, mode=mode) if mode == "sameshift": @@ -513,12 +510,11 @@ def construct_strided_conv2d_matrix( mode : The convolution type. Defaults to 'full'. Sameshift starts at 1 instead of 0. - Raises: - ValueError: Raised if an unknown convolution string is - provided. - Returns: - torch.Tensor: The sparse convolution tensor. + The sparse convolution tensor. + + Raises: + ValueError: Raised if an unknown convolution string is provided. """ filter_shape = filter.shape @@ -573,11 +569,11 @@ def batch_mm(matrix: torch.Tensor, matrix_batch: torch.Tensor) -> torch.Tensor: The former can be dense or sparse. Args: - matrix (torch.Tensor): Sparse or dense matrix, size (m, n). - matrix_batch (torch.Tensor): Batched dense matrices, size (b, n, k). + matrix (torch.Tensor): Sparse or dense matrix, of shape ``[m, n]``. + matrix_batch (torch.Tensor): Batched dense matrices, of shape ``[b, n, k]``. - Returns - torch.Tensor: The batched matrix-matrix product, size (b, m, k). + Returns: + The batched matrix-matrix product of shape ``[b, m, k]``. Raises: ValueError: If the matrices cannot be multiplied due to incompatible matrix @@ -598,12 +594,12 @@ def _batch_dim_mm( """Multiply batch_tensor with matrix along the dimensions specified in dim. Args: - matrix (torch.Tensor): A matrix of shape [m, n] + matrix (torch.Tensor): A matrix of shape ``[m, n]`` batch_tensor (torch.Tensor): A tensor with a selected dim of length n. dim (int): The position of the desired dimension. Returns: - torch.Tensor: The multiplication result. + The multiplication result. """ dim_length = batch_tensor.shape[dim] permuted_tensor = batch_tensor.transpose(dim, -1) diff --git a/src/ptwt/wavelets_learnable.py b/src/ptwt/wavelets_learnable.py index 3e0fa118..c28b541b 100644 --- a/src/ptwt/wavelets_learnable.py +++ b/src/ptwt/wavelets_learnable.py @@ -44,14 +44,13 @@ def pf_alias_cancellation_loss( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Return the product filter-alias cancellation loss. - See: Strang+Nguyen 105: F0(z) = H1(-z); F1(z) = -H0(-z) + See: Strang+Nguyen 105: $$F_0(z) = H_1(-z); F_1(z) = -H_0(-z)$$ Alternating sign convention from 0 to N see Strang overview on the back of the cover. Returns: - list: The numerical value of the alias cancellation loss, - as well as both loss components for analysis. - + The numerical value of the alias cancellation loss, + as well as both loss components for analysis. """ dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype) @@ -81,12 +80,11 @@ def alias_cancellation_loss( Implementation of the ac-loss as described on page 104 of Strang+Nguyen. 
- F0(z)H0(-z) + F1(z)H1(-z) = 0 + $$F_0(z)H_0(-z) + F_1(z)H_1(-z) = 0$$ Returns: - list: The numerical value of the alias cancellation loss, - as well as both loss components for analysis. - + The numerical value of the alias cancellation loss, + as well as both loss components for analysis. """ dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype) @@ -122,9 +120,8 @@ def perfect_reconstruction_loss( """Return the perfect reconstruction loss. Returns: - list: The numerical value of the alias cancellation loss, - as well as both intermediate values for analysis. - + The numerical value of the alias cancellation loss, + as well as both intermediate values for analysis. """ # Strang 107: Assuming alias cancellation holds: # P(z) = F(z)H(z) @@ -196,7 +193,7 @@ def __init__( def filter_bank( self, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Return all filters a a tuple.""" + """All filters a a tuple.""" return self.dec_lo, self.dec_hi, self.rec_lo, self.rec_hi # def parameters(self): @@ -210,7 +207,7 @@ def product_filter_loss(self) -> torch.Tensor: """Get only the product filter loss. Returns: - torch.Tensor: The loss scalar. + The loss scalar. """ return self.perfect_reconstruction_loss()[0] + self.alias_cancellation_loss()[0] @@ -218,7 +215,7 @@ def wavelet_loss(self) -> torch.Tensor: """Return the sum of all loss terms. Returns: - torch.Tensor: The loss scalar. + The loss scalar. """ return self.product_filter_loss() @@ -251,7 +248,7 @@ def rec_lo_orthogonality_loss(self) -> torch.Tensor: trough convolution. Returns: - torch.Tensor: A tensor with the orthogonality constraint value. + A tensor with the orthogonality constraint value. """ filt_len = self.dec_lo.shape[-1] pad_dec_lo = torch.cat( @@ -285,7 +282,7 @@ def filt_bank_orthogonality_loss(self) -> torch.Tensor: is presented. A measurement is implemented below. Returns: - torch.Tensor: A tensor with the orthogonality constraint value. + A tensor with the orthogonality constraint value. """ eq0 = self.dec_lo - self.rec_lo.flip(-1) eq1 = self.dec_hi - self.rec_hi.flip(-1) From 6d601325d77ad6ccbd0fd577bf08b80bdaf86c28 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 11 Jun 2024 17:54:18 +0200 Subject: [PATCH 20/40] Fix axis typehint and add to docstr --- src/ptwt/_stationary_transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ptwt/_stationary_transform.py b/src/ptwt/_stationary_transform.py index fb4f3798..dfe3e170 100644 --- a/src/ptwt/_stationary_transform.py +++ b/src/ptwt/_stationary_transform.py @@ -19,7 +19,7 @@ def _swt( data: torch.Tensor, wavelet: Union[Wavelet, str], level: Optional[int] = None, - axis: Optional[int] = -1, + axis: int = -1, ) -> list[torch.Tensor]: """Compute a multilevel 1d stationary wavelet transform. @@ -27,6 +27,7 @@ def _swt( data (torch.Tensor): The input data of shape ``[batch_size, time]``. wavelet (Union[Wavelet, str]): The wavelet to use. level (int, optional): The number of levels to compute. + axis (int): The axis to transform along. Defaults to the last axis. Returns: Same as wavedec. Equivalent to pywt.swt with trim_approx=True. From e035881e942bf2daf46110d211e87decd80dbdab Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 00:48:46 +0200 Subject: [PATCH 21/40] Move type aliases into public 'constants' module and add docstr. 
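
With this change the coefficient type aliases become importable from the
public 'ptwt.constants' module (and are re-exported from 'ptwt').
A minimal usage sketch; the 'count_levels' helper is purely illustrative
and not part of this patch:

    from ptwt.constants import WaveletCoeffDetailTuple2d

    def count_levels(coeffs: WaveletCoeffDetailTuple2d) -> int:
        # (A, T_L, ..., T_1): one approximation tensor followed by
        # one detail coefficient tuple per decomposition level.
        return len(coeffs) - 1
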
--- src/ptwt/__init__.py | 1 + src/ptwt/_util.py | 16 ++----- src/ptwt/constants.py | 68 ++++++++++++++++++++++++++++ src/ptwt/conv_transform.py | 3 +- src/ptwt/conv_transform_2.py | 3 +- src/ptwt/conv_transform_3.py | 3 +- src/ptwt/matmul_transform_2.py | 3 +- src/ptwt/matmul_transform_3.py | 3 +- src/ptwt/packets.py | 8 ++-- src/ptwt/separable_conv_transform.py | 3 +- 10 files changed, 84 insertions(+), 27 deletions(-) diff --git a/src/ptwt/__init__.py b/src/ptwt/__init__.py index 34808704..d0ad9418 100644 --- a/src/ptwt/__init__.py +++ b/src/ptwt/__init__.py @@ -1,6 +1,7 @@ """Differentiable and gpu enabled fast wavelet transforms in PyTorch.""" from ._util import Wavelet +from .constants import WaveletCoeffDetailDict, WaveletCoeffDetailTuple2d from .continuous_transform import cwt from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index ecac29f3..b29aa348 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -7,9 +7,12 @@ import numpy as np import pywt import torch -from typing_extensions import Unpack -from .constants import OrthogonalizeMethod +from .constants import ( + OrthogonalizeMethod, + WaveletCoeffDetailDict, + WaveletCoeffDetailTuple2d, +) class Wavelet(Protocol): @@ -31,15 +34,6 @@ def __len__(self) -> int: return len(self.dec_lo) -WaveletDetailTuple2d = tuple[torch.Tensor, torch.Tensor, torch.Tensor] -WaveletDetailDict = dict[str, torch.Tensor] - -WaveletCoeffDetailTuple2d = tuple[ - torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]] -] -WaveletCoeffDetailDict = tuple[torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]]] - - def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet: """Ensure the input argument to be a pywt wavelet compatible object. diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py index 7d9e92c8..35b55c1b 100644 --- a/src/ptwt/constants.py +++ b/src/ptwt/constants.py @@ -2,11 +2,18 @@ from typing import Literal, Union +import torch.Tensor +from typing_extensions import TypeAlias, Unpack + __all__ = [ "BoundaryMode", "ExtendedBoundaryMode", "PaddingMode", "OrthogonalizeMethod", + "WaveletDetailTuple2d", + "WaveletCoeffDetailTuple2d", + "WaveletCoeffDetailDict", + "WaveletDetailDict", ] BoundaryMode = Literal["constant", "zero", "reflect", "periodic", "symmetric"] @@ -36,3 +43,64 @@ Choose 'gramschmidt' if 'qr' runs out of memory. """ + + +WaveletDetailTuple2d: TypeAlias = tuple[torch.Tensor, torch.Tensor, torch.Tensor] +"""Detail coefficients of a 2d wavelet transform for a given level. + +This is a type alias for a tuple ``(H, V, D)`` of detail coefficient tensors +where ``H`` denotes horizontal, ``V`` vertical and ``D`` diagonal coefficients. + +Alias of ``tuple[torch.Tensor, torch.Tensor, torch.Tensor]`` +""" + + +WaveletDetailDict: TypeAlias = dict[str, torch.Tensor] +"""Type alias for a dict containing detail coefficient for a given level. + +Thus type alias represents the detail coefficient tensors of a given level for +a wavelet transform in :math:`N` dimensions as the values of a dictionary. +Its keys are a string of length :math:`N` describing the detail coefficient +by the applied filter for each axis where 'a' denotes the low pass +or approximation filter and 'd' the high-pass or detail filter. 
+For a 3d transform, the dictionary thus uses the keys:: + +("aad", "ada", "add", "daa", "dad", "dda", "ddd") + +Alias of ``dict[str, torch.Tensor]`` +""" + + +WaveletCoeffDetailTuple2d: TypeAlias = tuple[ + torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]] +] +"""Type alias for 2d wavelet transform results. + +This type alias represents the result of a 2d wavelet transform +with :math:`L` levels as a tuple ``(A, T1, T2, ...)`` of length :math:`L + 1` +where ``A`` denotes a tensor of approximation coefficients and +``Tl`` is a tuple of detail coefficients for level ``l``, +see :data:`ptwt.constants.WaveletDetailTuple2d`. + +Note that this type always contains an approximation coefficient tensor but does not +necesseraily contain any detail coefficients. + +Alias of ``tuple[torch.Tensor, *tuple[WaveletDetailTuple2d, ...]]`` +""" + +WaveletCoeffDetailDict: TypeAlias = tuple[ + torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]] +] +"""Type alias for wavelet transform results in any dimension. + +This type alias represents the result of a Nd wavelet transform +with :math:`L` levels as a tuple ``(A, D1, D2, ...)`` of length :math:`L + 1` +where ``A`` denotes a tensor of approximation coefficients and +``Dl`` is a dictionary of detail coefficients for level ``l``, +see :data:`ptwt.constants.WaveletDetailDict`. + +Note that this type always contains an approximation coefficient tensor but does not +necesseraily contain any detail coefficients. + +Alias of ``tuple[torch.Tensor, *tuple[WaveletDetailDict, ...]]`` +""" diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 9d06dee9..fac6d294 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -11,7 +11,6 @@ from ._util import ( Wavelet, - WaveletCoeffDetailTuple2d, _as_wavelet, _fold_axes, _get_len, @@ -19,7 +18,7 @@ _pad_symmetric, _unfold_axes, ) -from .constants import BoundaryMode +from .constants import BoundaryMode, WaveletCoeffDetailTuple2d def _create_tensor( diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index a63f19bd..782b6fd4 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -12,7 +12,6 @@ from ._util import ( Wavelet, - WaveletCoeffDetailTuple2d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -26,7 +25,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import BoundaryMode +from .constants import BoundaryMode, WaveletCoeffDetailTuple2d from .conv_transform import ( _adjust_padding_at_reconstruction, _get_filter_tensors, diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 69144d81..da77476c 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -11,7 +11,6 @@ from ._util import ( Wavelet, - WaveletCoeffDetailDict, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -25,7 +24,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import BoundaryMode +from .constants import BoundaryMode, WaveletCoeffDetailDict from .conv_transform import ( _adjust_padding_at_reconstruction, _get_filter_tensors, diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 55e2754d..9d96558f 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -12,7 +12,6 @@ from ._util import ( Wavelet, - WaveletCoeffDetailTuple2d, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -23,7 +22,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import OrthogonalizeMethod, PaddingMode +from .constants import OrthogonalizeMethod, PaddingMode, 
WaveletCoeffDetailTuple2d from .conv_transform import _get_filter_tensors from .conv_transform_2 import ( _construct_2d_filt, diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 0892955a..7e85e601 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -9,7 +9,6 @@ from ._util import ( Wavelet, - WaveletCoeffDetailDict, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -21,7 +20,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import OrthogonalizeMethod +from .constants import OrthogonalizeMethod, WaveletCoeffDetailDict from .conv_transform_3 import _waverec3d_fold_channels_3d_list from .matmul_transform import construct_boundary_a, construct_boundary_s from .sparse_math import _batch_dim_mm diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 6a0b9cad..dd4efb7f 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -10,13 +10,13 @@ import pywt import torch -from ._util import ( - Wavelet, +from ._util import Wavelet, _as_wavelet +from .constants import ( + ExtendedBoundaryMode, + OrthogonalizeMethod, WaveletCoeffDetailDict, WaveletCoeffDetailTuple2d, - _as_wavelet, ) -from .constants import ExtendedBoundaryMode, OrthogonalizeMethod from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 from .matmul_transform import MatrixWavedec, MatrixWaverec diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 8da099b1..afb13014 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -15,7 +15,6 @@ from ._util import ( Wavelet, - WaveletCoeffDetailDict, _as_wavelet, _check_axes_argument, _check_if_tensor, @@ -26,7 +25,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import BoundaryMode +from .constants import BoundaryMode, WaveletCoeffDetailDict from .conv_transform import wavedec, waverec from .conv_transform_2 import _preprocess_tensor_dec2d From 5051f13ab40f88eb40aba194d61daf53ee17b896 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 00:53:19 +0200 Subject: [PATCH 22/40] Fixate some type aliases to not be resolved in docs. Fixates some type aliases such that the API documentation shows the type alias name instead of its resolved value. 
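
The mapping is declared via Sphinx's 'autodoc_type_aliases' option in
docs/conf.py; a minimal sketch of the idea (the actual entries are added in
the docs/conf.py hunk below):

    autodoc_type_aliases = {
        "WaveletCoeffDetailTuple2d": "ptwt.constants.WaveletCoeffDetailTuple2d",
        "WaveletCoeffDetailDict": "ptwt.constants.WaveletCoeffDetailDict",
    }
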
For this feature to work we need to import the 'annotations' feature from __future__ --- docs/conf.py | 6 ++++++ src/ptwt/_util.py | 2 ++ src/ptwt/conv_transform.py | 2 ++ src/ptwt/conv_transform_2.py | 2 ++ src/ptwt/conv_transform_3.py | 2 ++ src/ptwt/matmul_transform_2.py | 2 ++ src/ptwt/matmul_transform_3.py | 2 ++ src/ptwt/packets.py | 2 ++ src/ptwt/separable_conv_transform.py | 2 ++ 9 files changed, 22 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index 2e872f84..63902e48 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -82,3 +82,9 @@ # numbered figures numfig = True + +autodoc_type_aliases = { + "WaveletCoeffDetailTuple2d": "ptwt.constants.WaveletCoeffDetailTuple2d", + "WaveletCoeffDetailDict": "ptwt.constants.WaveletCoeffDetailDict", + "BaseMatrixWaveDec": "ptwt.matmul_transform.BaseMatrixWaveDec", +} diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index b29aa348..cc742054 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -1,5 +1,7 @@ """Utility methods to compute wavelet decompositions from a dataset.""" +from __future__ import annotations + import typing from collections.abc import Sequence from typing import Any, Callable, Optional, Protocol, Union, cast, overload diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index fac6d294..78ba71b9 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -3,6 +3,8 @@ This module treats boundaries with edge-padding. """ +from __future__ import annotations + from collections.abc import Sequence from typing import Optional, Union diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 782b6fd4..ff1379ae 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -4,6 +4,8 @@ torch.nn.functional.conv_transpose2d under the hood. """ +from __future__ import annotations + from functools import partial from typing import Optional, Union diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index da77476c..78d0812d 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -3,6 +3,8 @@ The functions here are based on torch.nn.functional.conv3d and it's transpose. """ +from __future__ import annotations + from functools import partial from typing import Optional, Union diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 9d96558f..43a39b55 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -3,6 +3,8 @@ This module uses boundary filters to minimize padding. 
""" +from __future__ import annotations + import sys from functools import partial from typing import Optional, Union, cast diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 7e85e601..5035b4b4 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -1,5 +1,7 @@ """Implement 3D separable boundary transforms.""" +from __future__ import annotations + import sys from functools import partial from typing import NamedTuple, Optional, Union diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index dd4efb7f..8a017841 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -1,5 +1,7 @@ """Compute analysis wavelet packet representations.""" +from __future__ import annotations + import collections from collections.abc import Sequence from functools import partial diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index afb13014..ed4f3c33 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -7,6 +7,8 @@ using torch.nn.functional.conv1d and it's transpose. """ +from __future__ import annotations + from functools import partial from typing import Optional, Union From 1c6b4ea987dd9ffb27a5655964294b06a536493e Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 00:55:33 +0200 Subject: [PATCH 23/40] Improve docstrings --- src/ptwt/constants.py | 24 +++++++++++++++--------- src/ptwt/conv_transform.py | 2 ++ src/ptwt/conv_transform_2.py | 11 ++++++++--- src/ptwt/conv_transform_3.py | 12 +++++++++--- src/ptwt/matmul_transform.py | 6 +++++- src/ptwt/matmul_transform_2.py | 4 ++++ src/ptwt/matmul_transform_3.py | 9 +++++++-- src/ptwt/packets.py | 12 ++++++++++-- src/ptwt/separable_conv_transform.py | 19 +++++++++++-------- 9 files changed, 71 insertions(+), 28 deletions(-) diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py index 35b55c1b..829c47cf 100644 --- a/src/ptwt/constants.py +++ b/src/ptwt/constants.py @@ -18,16 +18,22 @@ BoundaryMode = Literal["constant", "zero", "reflect", "periodic", "symmetric"] """ -This is a type literal for the way of padding. +This is a type literal for the way of padding used at boundaries. -- Refection padding mirrors samples along the border. -- Zero padding pads zeros. -- Constant padding replicates border values. -- Periodic padding cyclically repeats samples. -- Symmetric padding mirrors samples along the border +- Refection padding mirrors samples along the border (``reflect``) +- Zero padding pads zeros (``zero``) +- Constant padding replicates border values (``constant``) +- Periodic padding cyclically repeats samples (``periodic``) +- Symmetric padding mirrors samples along the border (``symmetric``) """ ExtendedBoundaryMode = Union[Literal["boundary"], BoundaryMode] +""" +This is a type literal for the way of handling signal boundaries. + +This is either a form of padding (see :data:`ptwt.constants.BoundaryMode` +for padding options) or ``boundary`` to use boundary wavelets. +""" PaddingMode = Literal["full", "valid", "same", "sameshift"] """ @@ -38,10 +44,10 @@ """ The method for orthogonalizing a matrix. -1. 'qr' relies on pytorch's dense qr implementation, it is fast but memory hungry. -2. 'gramschmidt' option is sparse, memory efficient, and slow. +1. ``qr`` relies on pytorch's dense QR implementation, it is fast but memory hungry. +2. ``gramschmidt`` option is sparse, memory efficient, and slow. -Choose 'gramschmidt' if 'qr' runs out of memory. 
+Choose ``gramschmidt`` if ``qr`` runs out of memory. """ diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 78ba71b9..2e981ed7 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -359,6 +359,8 @@ def waverec( coeffs (Sequence): The wavelet coefficient sequence produced by wavedec. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. axis (int): Transform this axis instead of the last one. Defaults to -1. Returns: diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index ff1379ae..866d8d31 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -74,6 +74,8 @@ def _fwt_pad2( data (torch.Tensor): Input data with 4 dimensions. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. mode : The desired padding mode for extending the signal along the edges. Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`. @@ -156,10 +158,11 @@ def wavedec2( By default 2d inputs are interpreted as ``[height, width]``, 3d inputs are interpreted as ``[batch_size, height, width]``. 4d inputs are interpreted as ``[batch_size, channels, height, width]``. - the ``axis`` argument allows other interpretations. + The ``axes`` argument allows other interpretations. wavelet (Wavelet or str): A pywt wavelet compatible object or - the name of a pywt wavelet. Refer to the output of - ``pywt.wavelist(kind="discrete")`` for a list of possible choices. + the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. mode : The desired padding mode for extending the signal along the edges. Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`. @@ -258,6 +261,8 @@ def waverec2( and 'D' diagonal coefficients. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. axes (tuple[int, int]): Compute the transform over these axes instead of the last two. Defaults to (-2, -1). diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 78d0812d..a7387b85 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -75,6 +75,8 @@ def _fwt_pad3( data (torch.Tensor): Input data with 4 dimensions. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. mode : The desired padding mode for extending the signal along the edges. See :data:`ptwt.constants.BoundaryMode`. @@ -113,9 +115,11 @@ def wavedec3( Args: data (torch.Tensor): The input data. For example of shape - [batch_size, length, height, width] - wavelet (Union[Wavelet, str]): The wavelet to transform with. - ``pywt.wavelist(kind='discrete')`` lists possible choices. + ``[batch_size, length, height, width]`` + wavelet (Wavelet or str): A pywt wavelet compatible object or + the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. mode : The desired padding mode for extending the signal along the edges. Defaults to "zero". See :data:`ptwt.constants.BoundaryMode`. @@ -238,6 +242,8 @@ def waverec3( produced by wavedec3. 
wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. axes (tuple[int, int, int]): Transform these axes instead of the last three. Defaults to (-3, -2, -1). diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index d0f8f575..5498920a 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -188,11 +188,13 @@ def __init__( axis: Optional[int] = -1, boundary: OrthogonalizeMethod = "qr", ) -> None: - """Create a matrix-fwt object. + """A sparse matrix fast wavelet transform object. Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. level (int, optional): The level up to which to compute the fwt. If None, the maximum level based on the signal length is chosen. Defaults to None. @@ -478,6 +480,8 @@ def __init__( Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. axis (int): The axis transformed by the original decomposition defaults to -1 or the last axis. boundary : The method used for boundary filter treatment. diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 43a39b55..aa784249 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -262,6 +262,8 @@ def __init__( Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. level (int, optional): The level up to which to compute the fwt. If None, the maximum level based on the signal length is chosen. Defaults to None. @@ -577,6 +579,8 @@ def __init__( Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. axes (int, int): The axes transformed by waverec2. Defaults to (-2, -1). boundary : The method used for boundary filter treatment. diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 5035b4b4..4d243795 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -68,8 +68,11 @@ def __init__( this object transforms the last three dimensions. Args: - wavelet (Union[Wavelet, str]): The wavelet to use. - level (Optional[int]): The desired decomposition level. + wavelet (Wavelet or str): A pywt wavelet compatible object or + the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. + level (int, optional): The desired decomposition level. Defaults to None. boundary: The matrix orthogonalization method. Defaults to "qr". @@ -292,6 +295,8 @@ def __init__( Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. axes (tuple[int, int, int]): Transform these axes instead of the last three. Defaults to (-3, -2, -1). boundary : The method used for boundary filter treatment. diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 8a017841..426eb029 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -74,6 +74,8 @@ def __init__( Use the ``axis`` argument to choose another dimension. 
wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. mode : The desired padding method. If you select 'boundary', the sparse matrix backend will be used. Defaults to 'reflect'. maxlevel (int, optional): Value is passed on to `transform`. @@ -237,7 +239,8 @@ def __getitem__(self, key: str) -> torch.Tensor: Args: key (str): The key of the accessed coefficients. The string may only consist - of the following chars: 'a', 'd'. + of the chars 'a' and 'd' where 'a' denotes the low pass or + approximation filter and 'd' the high-pass or detail filter. Returns: The accessed wavelet packet coefficients. @@ -287,6 +290,8 @@ def __init__( a decomposition. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. mode : A string indicating the desired padding mode. If you select 'boundary', the sparse matrix backend is used. Defaults to 'reflect' @@ -490,7 +495,10 @@ def __getitem__(self, key: str) -> torch.Tensor: Args: key (str): The key of the accessed coefficients. The string may only consist - of the following chars: 'a', 'h', 'v', 'd'. + of the following chars: 'a', 'h', 'v', 'd' + The chars correspond to the selected coefficients for a level + where 'a' denotes the approximation coefficients and + 'h' horizontal, 'v' vertical and 'd' diagonal details coefficients. Returns: The accessed wavelet packet coefficients. diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index ed4f3c33..89e9b725 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -78,8 +78,7 @@ def _separable_conv_idwtn( in_dict (dict[str, torch.Tensor]): The dictionary produced by _separable_conv_dwtn_ . wavelet (Wavelet or str): A pywt wavelet compatible object or - the name of a pywt wavelet, as used by - _separable_conv_dwtn_ . + the name of a pywt wavelet, as used by ``_separable_conv_dwtn_``. Returns: A reconstruction of the original signal. @@ -161,7 +160,7 @@ def _separable_conv_waverecn( coeffs (WaveletCoeffDetailDict): The output as produced by `_separable_conv_wavedecn`. wavelet (Wavelet or str): A pywt wavelet compatible object or - the name of a pywt wavelet, as used by `_separable_conv_wavedecn`. + the name of a pywt wavelet, as used by ``_separable_conv_wavedecn``. Returns: The reconstruction of the original signal. @@ -264,7 +263,7 @@ def fswavedec3( data (torch.Tensor): An input signal of shape ``[batch, depth, height, width]``. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of - ``pywt.wavelist(kind="discrete")`` for a list of possible choices. + ``pywt.wavelist(kind="discrete")`` for possible choices. mode : The desired padding mode for extending the signal along the edges. Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`. @@ -335,8 +334,10 @@ def fswaverec2( Args: coeffs (WaveletCoeffDetailDict): The wavelet coefficients as computed by `fswavedec2`. - wavelet (Wavelet or str): The wavelet to use for the - synthesis transform. + wavelet (Wavelet or str): A pywt wavelet compatible object or + the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. axes (tuple[int, int]): Compute the transform over these axes instead of the last two. Defaults to (-2, -1). 
@@ -397,8 +398,10 @@ def fswaverec3( Args: coeffs (WaveletCoeffDetailDict): The wavelet coefficients as computed by `fswavedec3`. - wavelet (Wavelet or str): The wavelet to use for the - synthesis transform. + wavelet (Wavelet or str): A pywt wavelet compatible object or + the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. axes (tuple[int, int, int]): Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1). From 23ed9d829182d63182d5b23293ee621caeecbdf3 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 02:05:16 +0200 Subject: [PATCH 24/40] Replace redundant docstr info with refers to type alias --- src/ptwt/conv_transform_2.py | 17 +++----------- src/ptwt/conv_transform_3.py | 12 +++------- src/ptwt/matmul_transform.py | 23 +++++++----------- src/ptwt/matmul_transform_2.py | 35 ++++++++++------------------ src/ptwt/matmul_transform_3.py | 26 +++++++-------------- src/ptwt/packets.py | 11 +++++---- src/ptwt/separable_conv_transform.py | 16 ++++++++----- 7 files changed, 50 insertions(+), 90 deletions(-) diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 866d8d31..563e248b 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -172,13 +172,8 @@ def wavedec2( last two. Defaults to (-2, -1). Returns: - A tuple containing the wavelet coefficients. The coefficients are in pywt order. - That is:: - - [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] . - - 'A' denotes approximation, 'H' horizontal, 'V' vertical - and 'D' diagonal coefficients. + A tuple containing the wavelet coefficients in pywt order, + see :data:`ptwt.constants.WaveletCoeffDetailTuple2d`. Raises: ValueError: If the dimensionality or the dtype of the input data tensor @@ -252,13 +247,7 @@ def waverec2( Args: coeffs (WaveletCoeffDetailTuple2d): The wavelet coefficient tuple - produced by wavedec2. The coefficients must be in pywt order. - That is:: - - [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] . - - 'A' denotes approximation, 'H' horizontal, 'V' vertical, - and 'D' diagonal coefficients. + produced by wavedec2. See :data:`ptwt.constants.WaveletCoeffDetailTuple2d` wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from ``pywt.wavelist(kind='discrete')`` diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index a7387b85..38c114df 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -129,14 +129,8 @@ def wavedec3( instead of the last three. Defaults to (-3, -2, -1). Returns: - A tuple with the lll coefficients and for each scale a dictionary - containing the detail coefficients. The dictionaries use - the filter order strings:: - - ("aad", "ada", "add", "daa", "dad", "dda", "ddd") - - as keys. 'a' denotes the low pass or approximation filter and - 'd' the high-pass or detail filter. + A tuple containing the wavelet coefficients, + see :data:`ptwt.constants.WaveletCoeffDetailDict`. Raises: ValueError: If the input has fewer than three dimensions or @@ -239,7 +233,7 @@ def waverec3( Args: coeffs (WaveletCoeffDetailDict): The wavelet coefficient tuple - produced by wavedec3. + produced by wavedec3, see :data:`ptwt.constants.WaveletCoeffDetailDict`. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. 
Refer to the output from ``pywt.wavelist(kind='discrete')`` diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 5498920a..30b64cfc 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -200,11 +200,8 @@ def __init__( None. axis (int, optional): The axis we would like to transform. Defaults to -1. - boundary : The method used for boundary filter treatment. - Choose 'qr' or 'gramschmidt'. 'qr' relies on pytorch's dense qr - implementation, it is fast but memory hungry. The 'gramschmidt' - option is sparse, memory efficient, and slow. Choose 'gramschmidt' if - 'qr' runs out of memory. Defaults to 'qr'. + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. Raises: NotImplementedError: If the selected `boundary` mode is not supported. @@ -407,9 +404,8 @@ def construct_boundary_a( wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. length (int): The number of entries in the input signal. - boundary : A string indicating the desired boundary treatment. - Possible options are qr and gramschmidt. Defaults to - qr. + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. device: Where to place the matrix. Choose cpu or cuda. Defaults to cpu. dtype: Choose float32 or float64. @@ -438,8 +434,8 @@ def construct_boundary_s( length (int): The number of entries in the input signal. device (torch.device): Where to place the matrix. Choose cpu or cuda. Defaults to cpu. - boundary : A string indicating the desired boundary treatment. - Possible options are qr and gramschmidt. Defaults to qr. + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. dtype: Choose torch.float32 or torch.float64. Defaults to torch.float64. @@ -484,11 +480,8 @@ def __init__( for possible choices. axis (int): The axis transformed by the original decomposition defaults to -1 or the last axis. - boundary : The method used for boundary filter treatment. - Choose 'qr' or 'gramschmidt'. 'qr' relies on pytorch's dense qr - implementation, it is fast but memory hungry. The 'gramschmidt' option - is sparse, memory efficient, and slow. Choose 'gramschmidt' if 'qr' runs - out of memory. Defaults to 'qr'. + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. Raises: NotImplementedError: If the selected `boundary` mode is not supported. diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index aa784249..d9424ea9 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -160,8 +160,8 @@ def construct_boundary_a2( Should be divisible by two. device (torch.device): Where to place the matrix. Either on the CPU or GPU. - boundary : The method to use for matrix orthogonalization. - Choose "qr" or "gramschmidt". Defaults to "qr". + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. dtype (torch.dtype, optional): The desired data type for the matrix. Defaults to torch.float64. @@ -191,8 +191,8 @@ def construct_boundary_s2( height (int): The original height of the input matrix. width (int): The width of the original input matrix. device (torch.device): Choose CPU or GPU. - boundary : The method to use for matrix orthogonalization. 
- Choose qr or gramschmidt. Defaults to qr. + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. dtype (torch.dtype, optional): The data type of the sparse matrix, choose float32 or 64. Defaults to torch.float64. @@ -269,13 +269,8 @@ def __init__( None. axes (int, int): A tuple with the axes to transform. Defaults to (-2, -1). - boundary : The method used for boundary filter treatment. - Choose 'qr' or 'gramschmidt'. 'qr' relies on Pytorch's - dense qr implementation, it is fast but memory hungry. - The 'gramschmidt' option is sparse, memory efficient, - and slow. - Choose 'gramschmidt' if 'qr' runs out of memory. - Defaults to 'qr'. + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. separable (bool): If this flag is set, a separable transformation is used, i.e. a 1d transformation along each axis. Matrix construction is significantly faster for separable @@ -427,12 +422,8 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: This transform affects the last two dimensions. Returns: - The resulting coefficients per level are stored in a pywt style tuple. - The tuple is ordered as:: - - (ll, (lh, hl, hh), ...) - - with 'l' for low-pass and 'h' for high-pass filters. + The resulting coefficients per level are stored in a pywt style tuple, + see :data:`ptwt.constants.WaveletCoeffDetailTuple2d`. Raises: ValueError: If the decomposition level is not a positive integer @@ -583,11 +574,8 @@ def __init__( for possible choices. axes (int, int): The axes transformed by waverec2. Defaults to (-2, -1). - boundary : The method used for boundary filter treatment. - Choose 'qr' or 'gramschmidt'. 'qr' relies on pytorch's dense qr - implementation, it is fast but memory hungry. The 'gramschmidt' option - is sparse, memory efficient, and slow. Choose 'gramschmidt' if 'qr' runs - out of memory. Defaults to 'qr'. + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. separable (bool): If this flag is set, a separable transformation is used, i.e. a 1d transformation along each axis. This is significantly faster than a non-separable transformation since only a small constant- @@ -729,7 +717,8 @@ def __call__( Args: coefficients (WaveletCoeffDetailTuple2d): The coefficient tuple as returned - by the `MatrixWavedec2`-Object. + by the `MatrixWavedec2` object, + see :data:`ptwt.constants.WaveletCoeffDetailTuple2d`. Returns: The original signal reconstruction. For example of shape diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 4d243795..e4d4f0af 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -74,8 +74,8 @@ def __init__( for possible choices. level (int, optional): The desired decomposition level. Defaults to None. - boundary: The matrix orthogonalization method. - Defaults to "qr". + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. Raises: NotImplementedError: If the chosen orthogonalization method @@ -168,14 +168,8 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailDict: of shape ``[batch_size, depth, height, width]``. Returns: - A tuple with the lll coefficients and for each scale a dictionary - containing the detail coefficients. 
The dictionaries use - the filter order strings:: - - ("aad", "ada", "add", "daa", "dad", "dda", "ddd") - - as keys. 'a' denotes the low pass or approximation filter and - 'd' the high-pass or detail filter. + The resulting coefficients for each level are stored in a tuple, + see :data:`ptwt.constants.WaveletCoeffDetailDict`. Raises: ValueError: If the input dimensions don't work. @@ -299,11 +293,8 @@ def __init__( for possible choices. axes (tuple[int, int, int]): Transform these axes instead of the last three. Defaults to (-3, -2, -1). - boundary : The method used for boundary filter treatment. - Choose 'qr' or 'gramschmidt'. 'qr' relies on Pytorch's dense qr - implementation, it is fast but memory hungry. The 'gramschmidt' option - is sparse, memory efficient, and slow. Choose 'gramschmidt' if 'qr' runs - out of memory. Defaults to 'qr'. + boundary : The method used for boundary filter treatment, + see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. Raises: NotImplementedError: If the selected `boundary` mode is not supported. @@ -402,9 +393,8 @@ def __call__(self, coefficients: WaveletCoeffDetailDict) -> torch.Tensor: Args: coefficients (WaveletCoeffDetailDict): - The output from MatrixWavedec3, consisting of a tuple - of the approximation coefficients and a dict with the - detail coefficients for each scale. + The output from the `MatrixWavedec3` object, + see :data:`ptwt.constants.WaveletCoeffDetailDict`. Returns: torch.Tensor: A reconstruction of the original signal. diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 426eb029..fee4666a 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -83,8 +83,9 @@ def __init__( is determined from the input data shape. Defaults to None. axis (int): The axis to transform. Defaults to -1. boundary_orthogonalization : The orthogonalization method - to use. Only used if `mode` equals 'boundary'. Choose from - 'qr' or 'gramschmidt'. Defaults to 'qr'. + to use in the sparse matrix backend, + see :data:`ptwt.constants.OrthogonalizeMethod`. + Only used if `mode` equals 'boundary'. Defaults to 'qr'. Example: >>> import torch, pywt, ptwt @@ -301,9 +302,9 @@ def __init__( axes ([int, int], optional): The tensor axes that should be transformed. Defaults to (-2, -1). boundary_orthogonalization : The orthogonalization method - to use in the sparse matrix backend. Only used if `mode` - equals 'boundary'. Choose from 'qr' or 'gramschmidt'. - Defaults to 'qr'. + to use in the sparse matrix backend, + see :data:`ptwt.constants.OrthogonalizeMethod`. + Only used if `mode` equals 'boundary'. Defaults to 'qr'. separable (bool): If true, a separable transform is performed, i.e. each image axis is transformed separately. Defaults to False. diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 89e9b725..ff169ec8 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -207,8 +207,9 @@ def fswavedec2( Returns: A tuple with the ll coefficients and for each scale a dictionary - containing the detail coefficients. The dictionaries use - the filter order strings:: + containing the detail coefficients, + see :data:`ptwt.constants.WaveletCoeffDetailDict`. + The dictionaries use the filter order strings:: ("ad", "da", "dd") @@ -274,8 +275,9 @@ def fswavedec3( Returns: A tuple with the lll coefficients and for each scale a dictionary - containing the detail coefficients. 
The dictionaries use - the filter order strings:: + containing the detail coefficients, + see :data:`ptwt.constants.WaveletCoeffDetailDict`. + The dictionaries use the filter order strings:: ("aad", "ada", "add", "daa", "dad", "dda", "ddd") @@ -333,7 +335,8 @@ def fswaverec2( Args: coeffs (WaveletCoeffDetailDict): - The wavelet coefficients as computed by `fswavedec2`. + The wavelet coefficients as computed by `fswavedec2`, + see :data:`ptwt.constants.WaveletCoeffDetailDict`. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from ``pywt.wavelist(kind='discrete')`` @@ -397,7 +400,8 @@ def fswaverec3( Args: coeffs (WaveletCoeffDetailDict): - The wavelet coefficients as computed by `fswavedec3`. + The wavelet coefficients as computed by `fswavedec3`, + see :data:`ptwt.constants.WaveletCoeffDetailDict`. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from ``pywt.wavelist(kind='discrete')`` From 1139876c38cd8a7bced2bbf3ddc1ecc1fdc171ca Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 02:05:55 +0200 Subject: [PATCH 25/40] Add some special member funcs to docs --- docs/ptwt.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/ptwt.rst b/docs/ptwt.rst index 7f197d24..222b747d 100644 --- a/docs/ptwt.rst +++ b/docs/ptwt.rst @@ -32,6 +32,7 @@ ptwt.packets module .. automodule:: ptwt.packets :members: + :special-members: __getitem__ :undoc-members: :show-inheritance: @@ -58,6 +59,7 @@ ptwt.matmul\_transform module .. automodule:: ptwt.matmul_transform :members: + :special-members: __call__ :undoc-members: :show-inheritance: @@ -66,6 +68,7 @@ ptwt.matmul\_transform\_2 module .. automodule:: ptwt.matmul_transform_2 :members: + :special-members: __call__ :undoc-members: :show-inheritance: @@ -74,6 +77,7 @@ ptwt.matmul\_transform\_3 module .. automodule:: ptwt.matmul_transform_3 :members: + :special-members: __call__ :undoc-members: :show-inheritance: From dd10c825ab0ff448ccd38a9bb4de7d48d7c9b9b6 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 02:06:15 +0200 Subject: [PATCH 26/40] Fix favicon and move version module in API --- docs/conf.py | 2 -- docs/ptwt.rst | 16 ++++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 63902e48..e0aaa4c3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -70,8 +70,6 @@ html_favicon = "_static/favicon.ico" html_logo = "_static/shannon.png" -html_favicon = "favicon/favicon.ico" - # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". diff --git a/docs/ptwt.rst b/docs/ptwt.rst index 222b747d..4dad3a21 100644 --- a/docs/ptwt.rst +++ b/docs/ptwt.rst @@ -90,14 +90,6 @@ ptwt.sparse\_math module :undoc-members: :show-inheritance: -ptwt.version module -------------------- - -.. automodule:: ptwt.version - :members: - :undoc-members: - :show-inheritance: - ptwt.wavelets\_learnable module ------------------------------- @@ -112,3 +104,11 @@ ptwt.constants :members: :undoc-members: :show-inheritance: + +ptwt.version module +------------------- + +.. 
automodule:: ptwt.version + :members: + :undoc-members: + :show-inheritance: From 0f3d01a9655a675b36f016da702778c9711081b8 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 02:06:30 +0200 Subject: [PATCH 27/40] Fix import --- src/ptwt/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py index 829c47cf..ecb9f548 100644 --- a/src/ptwt/constants.py +++ b/src/ptwt/constants.py @@ -2,7 +2,7 @@ from typing import Literal, Union -import torch.Tensor +import torch from typing_extensions import TypeAlias, Unpack __all__ = [ From cd4c1315580abee51268f9c8bfd60644e51e5fc4 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 02:28:58 +0200 Subject: [PATCH 28/40] Change detail coeff tuple to namedtuple --- src/ptwt/constants.py | 16 +++++++++------- src/ptwt/conv_transform_2.py | 8 +++++--- src/ptwt/matmul_transform_2.py | 13 +++++++++---- src/ptwt/packets.py | 8 ++++++-- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py index ecb9f548..9adffced 100644 --- a/src/ptwt/constants.py +++ b/src/ptwt/constants.py @@ -1,6 +1,6 @@ """Constants and types used throughout the PyTorch Wavelet Toolbox.""" -from typing import Literal, Union +from typing import Literal, NamedTuple, Union import torch from typing_extensions import TypeAlias, Unpack @@ -51,14 +51,16 @@ """ -WaveletDetailTuple2d: TypeAlias = tuple[torch.Tensor, torch.Tensor, torch.Tensor] -"""Detail coefficients of a 2d wavelet transform for a given level. +class WaveletDetailTuple2d(NamedTuple): + """Detail coefficients of a 2d wavelet transform for a given level. -This is a type alias for a tuple ``(H, V, D)`` of detail coefficient tensors -where ``H`` denotes horizontal, ``V`` vertical and ``D`` diagonal coefficients. + This is a type alias for a named tuple ``(H, V, D)`` of detail coefficient tensors + where ``H`` denotes horizontal, ``V`` vertical and ``D`` diagonal coefficients. 
+ """ -Alias of ``tuple[torch.Tensor, torch.Tensor, torch.Tensor]`` -""" + horizontal: torch.Tensor + vertical: torch.Tensor + diagonal: torch.Tensor WaveletDetailDict: TypeAlias = dict[str, torch.Tensor] diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 563e248b..4a6a48f3 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -27,7 +27,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import BoundaryMode, WaveletCoeffDetailTuple2d +from .constants import BoundaryMode, WaveletCoeffDetailTuple2d, WaveletDetailTuple2d from .conv_transform import ( _adjust_padding_at_reconstruction, _get_filter_tensors, @@ -211,13 +211,15 @@ def wavedec2( if level is None: level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet) - result_lst: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] + result_lst: list[WaveletDetailTuple2d] = [] res_ll = data for _ in range(level): res_ll = _fwt_pad2(res_ll, wavelet, mode=mode) res = torch.nn.functional.conv2d(res_ll, dec_filt, stride=2) res_ll, res_lh, res_hl, res_hh = torch.split(res, 1, 1) - to_append = (res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1)) + to_append = WaveletDetailTuple2d( + res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1) + ) result_lst.append(to_append) result_lst.reverse() diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index d9424ea9..27cb016d 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -24,7 +24,12 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import OrthogonalizeMethod, PaddingMode, WaveletCoeffDetailTuple2d +from .constants import ( + OrthogonalizeMethod, + PaddingMode, + WaveletCoeffDetailTuple2d, + WaveletDetailTuple2d, +) from .conv_transform import _get_filter_tensors from .conv_transform_2 import ( _construct_2d_filt, @@ -463,7 +468,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: device=input_signal.device, dtype=input_signal.dtype ) - split_list: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] + split_list: list[WaveletDetailTuple2d] = [] if self.separable: ll = input_signal for scale, fwt_mats in enumerate(self.fwt_matrix_list): @@ -485,7 +490,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: ll, lh = torch.split(a_coeffs, current_width // 2, dim=-1) hl, hh = torch.split(d_coeffs, current_width // 2, dim=-1) - split_list.append((lh, hl, hh)) + split_list.append(WaveletDetailTuple2d(lh, hl, hh)) else: ll = input_signal.transpose(-2, -1).reshape([batch_size, -1]).T for scale, fwt_matrix in enumerate(self.fwt_matrix_list): @@ -525,7 +530,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: for el in four_split[1:] ), ) - split_list.append(reshaped) + split_list.append(WaveletDetailTuple2d(*reshaped)) ll = four_split[0] ll = ll.T.reshape(batch_size, size[1] // 2, size[0] // 2).transpose(2, 1) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index fee4666a..62245ad3 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -18,6 +18,7 @@ OrthogonalizeMethod, WaveletCoeffDetailDict, WaveletCoeffDetailTuple2d, + WaveletDetailTuple2d, ) from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 @@ -369,7 +370,7 @@ def reconstruct(self) -> "WaveletPacket2D": data_v = self[node + "v"] data_d = self[node + "d"] rec = self._get_waverec(data_a.shape[-2:])( - (data_a, (data_h, data_v, data_d)) + (data_a, WaveletDetailTuple2d(data_h, 
data_v, data_d)) ) if level > 0: if rec.shape[-1] != self[node].shape[-1]: @@ -458,7 +459,10 @@ def _tuple_func( # assert for type checking assert len(fs_dict_data) == 2 a_coeff, fsdict = fs_dict_data - return (a_coeff, (fsdict["ad"], fsdict["da"], fsdict["dd"])) + return ( + a_coeff, + WaveletDetailTuple2d(fsdict["ad"], fsdict["da"], fsdict["dd"]), + ) return _tuple_func From b1c34d56b2feaad7167f4e28d79ec2ee6eb87bcb Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 02:33:57 +0200 Subject: [PATCH 29/40] Change type str back to imperative mood --- src/ptwt/matmul_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 30b64cfc..a8e52fe1 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -188,7 +188,7 @@ def __init__( axis: Optional[int] = -1, boundary: OrthogonalizeMethod = "qr", ) -> None: - """A sparse matrix fast wavelet transform object. + """Create a sparse matrix fast wavelet transform object. Args: wavelet (Wavelet or str): A pywt wavelet compatible object or From d4ed327246688c1a191362901e160e64fbe2158f Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 03:06:38 +0200 Subject: [PATCH 30/40] Add comment clarification --- src/ptwt/constants.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py index 9adffced..6c34de50 100644 --- a/src/ptwt/constants.py +++ b/src/ptwt/constants.py @@ -79,6 +79,7 @@ class WaveletDetailTuple2d(NamedTuple): """ +# Note: This data structure was chosen to follow pywt's conventions WaveletCoeffDetailTuple2d: TypeAlias = tuple[ torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]] ] @@ -96,6 +97,7 @@ class WaveletDetailTuple2d(NamedTuple): Alias of ``tuple[torch.Tensor, *tuple[WaveletDetailTuple2d, ...]]`` """ +# Note: This data structure was chosen to follow pywt's conventions WaveletCoeffDetailDict: TypeAlias = tuple[ torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]] ] From 1991a071ea7c2e1506392ef7bf2c31aeaa63f56c Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 13:05:08 +0200 Subject: [PATCH 31/40] Rename type aliases. 
Rename * `WaveletCoeffDetailDict` to `WaveletCoeffNd` * `WaveletCoeffDetailTuple2d` to `WaveletCoeff2d` --- docs/conf.py | 4 ++-- src/ptwt/__init__.py | 2 +- src/ptwt/_util.py | 22 ++++++++-------------- src/ptwt/constants.py | 10 ++++------ src/ptwt/conv_transform.py | 6 +++--- src/ptwt/conv_transform_2.py | 18 +++++++++--------- src/ptwt/conv_transform_3.py | 18 +++++++++--------- src/ptwt/matmul_transform_2.py | 14 +++++++------- src/ptwt/matmul_transform_3.py | 14 +++++++------- src/ptwt/packets.py | 20 ++++++++++---------- src/ptwt/separable_conv_transform.py | 28 ++++++++++++++-------------- 11 files changed, 74 insertions(+), 82 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index e0aaa4c3..d31a6b76 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -82,7 +82,7 @@ numfig = True autodoc_type_aliases = { - "WaveletCoeffDetailTuple2d": "ptwt.constants.WaveletCoeffDetailTuple2d", - "WaveletCoeffDetailDict": "ptwt.constants.WaveletCoeffDetailDict", + "WaveletCoeff2d": "ptwt.constants.WaveletCoeff2d", + "WaveletCoeffNd": "ptwt.constants.WaveletCoeffNd", "BaseMatrixWaveDec": "ptwt.matmul_transform.BaseMatrixWaveDec", } diff --git a/src/ptwt/__init__.py b/src/ptwt/__init__.py index d0ad9418..5ac62411 100644 --- a/src/ptwt/__init__.py +++ b/src/ptwt/__init__.py @@ -1,7 +1,7 @@ """Differentiable and gpu enabled fast wavelet transforms in PyTorch.""" from ._util import Wavelet -from .constants import WaveletCoeffDetailDict, WaveletCoeffDetailTuple2d +from .constants import WaveletCoeff2d, WaveletCoeffNd from .continuous_transform import cwt from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index cc742054..812546c4 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -10,11 +10,7 @@ import pywt import torch -from .constants import ( - OrthogonalizeMethod, - WaveletCoeffDetailDict, - WaveletCoeffDetailTuple2d, -) +from .constants import OrthogonalizeMethod, WaveletCoeff2d, WaveletCoeffNd class Wavelet(Protocol): @@ -172,22 +168,22 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: @overload def _map_result( - data: WaveletCoeffDetailTuple2d, + data: WaveletCoeff2d, function: Callable[[torch.Tensor], torch.Tensor], -) -> WaveletCoeffDetailTuple2d: ... +) -> WaveletCoeff2d: ... @overload def _map_result( - data: WaveletCoeffDetailDict, + data: WaveletCoeffNd, function: Callable[[torch.Tensor], torch.Tensor], -) -> WaveletCoeffDetailDict: ... +) -> WaveletCoeffNd: ... 
def _map_result( - data: Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], + data: Union[WaveletCoeff2d, WaveletCoeffNd], function: Callable[[torch.Tensor], torch.Tensor], -) -> Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict]: +) -> Union[WaveletCoeff2d, WaveletCoeffNd]: approx = function(data[0]) result_lst: list[ Union[ @@ -211,7 +207,5 @@ def _map_result( raise ValueError(f"Unexpected input type {type(element)}") return_val = approx, *result_lst - return_val = cast( - Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], return_val - ) + return_val = cast(Union[WaveletCoeff2d, WaveletCoeffNd], return_val) return return_val diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py index 6c34de50..5b274e3a 100644 --- a/src/ptwt/constants.py +++ b/src/ptwt/constants.py @@ -11,8 +11,8 @@ "PaddingMode", "OrthogonalizeMethod", "WaveletDetailTuple2d", - "WaveletCoeffDetailTuple2d", - "WaveletCoeffDetailDict", + "WaveletCoeff2d", + "WaveletCoeffNd", "WaveletDetailDict", ] @@ -80,7 +80,7 @@ class WaveletDetailTuple2d(NamedTuple): # Note: This data structure was chosen to follow pywt's conventions -WaveletCoeffDetailTuple2d: TypeAlias = tuple[ +WaveletCoeff2d: TypeAlias = tuple[ torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]] ] """Type alias for 2d wavelet transform results. @@ -98,9 +98,7 @@ class WaveletDetailTuple2d(NamedTuple): """ # Note: This data structure was chosen to follow pywt's conventions -WaveletCoeffDetailDict: TypeAlias = tuple[ - torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]] -] +WaveletCoeffNd: TypeAlias = tuple[torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]]] """Type alias for wavelet transform results in any dimension. This type alias represents the result of a Nd wavelet transform diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 2aa0cc4a..f4898961 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -20,7 +20,7 @@ _pad_symmetric, _unfold_axes, ) -from .constants import BoundaryMode, WaveletCoeffDetailTuple2d +from .constants import BoundaryMode, WaveletCoeff2d def _create_tensor( @@ -168,13 +168,13 @@ def _fwt_pad( def _flatten_2d_coeff_lst( - coeff_lst_2d: WaveletCoeffDetailTuple2d, + coeff_lst_2d: WaveletCoeff2d, flatten_tensors: bool = True, ) -> list[torch.Tensor]: """Flattens a sequence of tensor tuples into a single list. Args: - coeff_lst_2d (WaveletCoeffDetailTuple2d): A pywt-style + coeff_lst_2d (WaveletCoeff2d): A pywt-style coefficient tuple of torch tensors. flatten_tensors (bool): If true, 2d tensors are flattened. Defaults to True. diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 4a6a48f3..1136916f 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -27,7 +27,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import BoundaryMode, WaveletCoeffDetailTuple2d, WaveletDetailTuple2d +from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d from .conv_transform import ( _adjust_padding_at_reconstruction, _get_filter_tensors, @@ -100,8 +100,8 @@ def _fwt_pad2( def _waverec2d_fold_channels_2d_list( - coeffs: WaveletCoeffDetailTuple2d, -) -> tuple[WaveletCoeffDetailTuple2d, list[int]]: + coeffs: WaveletCoeff2d, +) -> tuple[WaveletCoeff2d, list[int]]: # fold the input coefficients for processing conv2d_transpose. 
ds = list(_check_if_tensor(coeffs[0]).shape) return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds @@ -132,7 +132,7 @@ def wavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int] = (-2, -1), -) -> WaveletCoeffDetailTuple2d: +) -> WaveletCoeff2d: r"""Run a two-dimensional wavelet transformation. This function relies on two-dimensional convolutions. @@ -173,7 +173,7 @@ def wavedec2( Returns: A tuple containing the wavelet coefficients in pywt order, - see :data:`ptwt.constants.WaveletCoeffDetailTuple2d`. + see :data:`ptwt.constants.WaveletCoeff2d`. Raises: ValueError: If the dimensionality or the dtype of the input data tensor @@ -224,7 +224,7 @@ def wavedec2( result_lst.reverse() res_ll = res_ll.squeeze(1) - result: WaveletCoeffDetailTuple2d = res_ll, *result_lst + result: WaveletCoeff2d = res_ll, *result_lst if ds: _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) @@ -238,7 +238,7 @@ def wavedec2( def waverec2( - coeffs: WaveletCoeffDetailTuple2d, + coeffs: WaveletCoeff2d, wavelet: Union[Wavelet, str], axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: @@ -248,8 +248,8 @@ def waverec2( or forward transform by running transposed convolutions. Args: - coeffs (WaveletCoeffDetailTuple2d): The wavelet coefficient tuple - produced by wavedec2. See :data:`ptwt.constants.WaveletCoeffDetailTuple2d` + coeffs (WaveletCoeff2d): The wavelet coefficient tuple + produced by wavedec2. See :data:`ptwt.constants.WaveletCoeff2d` wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from ``pywt.wavelist(kind='discrete')`` diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 38c114df..1e555cff 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -26,7 +26,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import BoundaryMode, WaveletCoeffDetailDict +from .constants import BoundaryMode, WaveletCoeffNd from .conv_transform import ( _adjust_padding_at_reconstruction, _get_filter_tensors, @@ -110,7 +110,7 @@ def wavedec3( mode: BoundaryMode = "zero", level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), -) -> WaveletCoeffDetailDict: +) -> WaveletCoeffNd: """Compute a three-dimensional wavelet transform. Args: @@ -130,7 +130,7 @@ def wavedec3( Returns: A tuple containing the wavelet coefficients, - see :data:`ptwt.constants.WaveletCoeffDetailDict`. + see :data:`ptwt.constants.WaveletCoeffNd`. Raises: ValueError: If the input has fewer than three dimensions or @@ -194,7 +194,7 @@ def wavedec3( } ) result_lst.reverse() - result: WaveletCoeffDetailDict = res_lll, *result_lst + result: WaveletCoeffNd = res_lll, *result_lst if ds: _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) @@ -208,9 +208,9 @@ def wavedec3( def _waverec3d_fold_channels_3d_list( - coeffs: WaveletCoeffDetailDict, + coeffs: WaveletCoeffNd, ) -> tuple[ - WaveletCoeffDetailDict, + WaveletCoeffNd, list[int], ]: # fold the input coefficients for processing conv2d_transpose. @@ -225,15 +225,15 @@ def _waverec3d_fold_channels_3d_list( def waverec3( - coeffs: WaveletCoeffDetailDict, + coeffs: WaveletCoeffNd, wavelet: Union[Wavelet, str], axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. Args: - coeffs (WaveletCoeffDetailDict): The wavelet coefficient tuple - produced by wavedec3, see :data:`ptwt.constants.WaveletCoeffDetailDict`. 
+ coeffs (WaveletCoeffNd): The wavelet coefficient tuple + produced by wavedec3, see :data:`ptwt.constants.WaveletCoeffNd`. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from ``pywt.wavelist(kind='discrete')`` diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 27cb016d..ca629f14 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -27,7 +27,7 @@ from .constants import ( OrthogonalizeMethod, PaddingMode, - WaveletCoeffDetailTuple2d, + WaveletCoeff2d, WaveletDetailTuple2d, ) from .conv_transform import _get_filter_tensors @@ -413,7 +413,7 @@ def _construct_analysis_matrices( current_width = current_width // 2 self.size_list.append((current_height, current_width)) - def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: + def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d: """Compute the fwt for the given input signal. The fwt matrix is set up during the first call @@ -428,7 +428,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: Returns: The resulting coefficients per level are stored in a pywt style tuple, - see :data:`ptwt.constants.WaveletCoeffDetailTuple2d`. + see :data:`ptwt.constants.WaveletCoeff2d`. Raises: ValueError: If the decomposition level is not a positive integer @@ -535,7 +535,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: ll = ll.T.reshape(batch_size, size[1] // 2, size[0] // 2).transpose(2, 1) split_list.reverse() - result: WaveletCoeffDetailTuple2d = ll, *split_list + result: WaveletCoeff2d = ll, *split_list if ds: _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) @@ -716,14 +716,14 @@ def _construct_synthesis_matrices( def __call__( self, - coefficients: WaveletCoeffDetailTuple2d, + coefficients: WaveletCoeff2d, ) -> torch.Tensor: """Compute the inverse matrix 2d fast wavelet transform. Args: - coefficients (WaveletCoeffDetailTuple2d): The coefficient tuple as returned + coefficients (WaveletCoeff2d): The coefficient tuple as returned by the `MatrixWavedec2` object, - see :data:`ptwt.constants.WaveletCoeffDetailTuple2d`. + see :data:`ptwt.constants.WaveletCoeff2d`. Returns: The original signal reconstruction. For example of shape diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index e4d4f0af..dba3fae0 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -22,7 +22,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import OrthogonalizeMethod, WaveletCoeffDetailDict +from .constants import OrthogonalizeMethod, WaveletCoeffNd from .conv_transform_3 import _waverec3d_fold_channels_3d_list from .matmul_transform import construct_boundary_a, construct_boundary_s from .sparse_math import _batch_dim_mm @@ -160,7 +160,7 @@ def _construct_analysis_matrices( ) self.size_list.append((current_depth, current_height, current_width)) - def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailDict: + def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffNd: """Compute a separable 3d-boundary wavelet transform. Args: @@ -169,7 +169,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailDict: Returns: The resulting coefficients for each level are stored in a tuple, - see :data:`ptwt.constants.WaveletCoeffDetailDict`. + see :data:`ptwt.constants.WaveletCoeffNd`. Raises: ValueError: If the input dimensions don't work. 
@@ -262,7 +262,7 @@ def _split_rec( split_list.append(coeff_dict) split_list.reverse() - result: WaveletCoeffDetailDict = lll, *split_list + result: WaveletCoeffNd = lll, *split_list if ds: _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) @@ -388,13 +388,13 @@ def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Ten return cat_tensor return self._cat_coeff_recursive(done_dict) - def __call__(self, coefficients: WaveletCoeffDetailDict) -> torch.Tensor: + def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: """Reconstruct a batched 3d-signal from its coefficients. Args: - coefficients (WaveletCoeffDetailDict): + coefficients (WaveletCoeffNd): The output from the `MatrixWavedec3` object, - see :data:`ptwt.constants.WaveletCoeffDetailDict`. + see :data:`ptwt.constants.WaveletCoeffNd`. Returns: torch.Tensor: A reconstruction of the original signal. diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 62245ad3..33b984de 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -16,8 +16,8 @@ from .constants import ( ExtendedBoundaryMode, OrthogonalizeMethod, - WaveletCoeffDetailDict, - WaveletCoeffDetailTuple2d, + WaveletCoeff2d, + WaveletCoeffNd, WaveletDetailTuple2d, ) from .conv_transform import wavedec, waverec @@ -399,7 +399,7 @@ def get_natural_order(self, level: int) -> list[str]: def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ [torch.Tensor], - WaveletCoeffDetailTuple2d, + WaveletCoeff2d, ]: if self.mode == "boundary": shape = tuple(shape) @@ -430,7 +430,7 @@ def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ def _get_waverec( self, shape: tuple[int, ...] - ) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: + ) -> Callable[[WaveletCoeff2d], torch.Tensor]: if self.mode == "boundary": shape = tuple(shape) if shape not in self.matrix_waverec2_dict.keys(): @@ -450,11 +450,11 @@ def _get_waverec( def _transform_fsdict_to_tuple_func( self, - fs_dict_func: Callable[[torch.Tensor], WaveletCoeffDetailDict], - ) -> Callable[[torch.Tensor], WaveletCoeffDetailTuple2d]: + fs_dict_func: Callable[[torch.Tensor], WaveletCoeffNd], + ) -> Callable[[torch.Tensor], WaveletCoeff2d]: def _tuple_func( data: torch.Tensor, - ) -> WaveletCoeffDetailTuple2d: + ) -> WaveletCoeff2d: fs_dict_data = fs_dict_func(data) # assert for type checking assert len(fs_dict_data) == 2 @@ -468,9 +468,9 @@ def _tuple_func( def _transform_tuple_to_fsdict_func( self, - fsdict_func: Callable[[WaveletCoeffDetailDict], torch.Tensor], - ) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: - def _fsdict_func(coeffs: WaveletCoeffDetailTuple2d) -> torch.Tensor: + fsdict_func: Callable[[WaveletCoeffNd], torch.Tensor], + ) -> Callable[[WaveletCoeff2d], torch.Tensor]: + def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor: # assert for type checking assert len(coeffs) == 2 a, (h, v, d) = coeffs diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index ff169ec8..6bccaac5 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -27,7 +27,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import BoundaryMode, WaveletCoeffDetailDict +from .constants import BoundaryMode, WaveletCoeffNd from .conv_transform import wavedec, waverec from .conv_transform_2 import _preprocess_tensor_dec2d @@ -114,7 +114,7 @@ def _separable_conv_wavedecn( *, mode: BoundaryMode = "reflect", level: Optional[int] = None, -) -> WaveletCoeffDetailDict: +) -> WaveletCoeffNd: """Compute a multilevel 
separable padded wavelet analysis transform. Args: @@ -151,13 +151,13 @@ def _separable_conv_wavedecn( def _separable_conv_waverecn( - coeffs: WaveletCoeffDetailDict, + coeffs: WaveletCoeffNd, wavelet: Union[Wavelet, str], ) -> torch.Tensor: """Separable n-dimensional wavelet synthesis transform. Args: - coeffs (WaveletCoeffDetailDict): + coeffs (WaveletCoeffNd): The output as produced by `_separable_conv_wavedecn`. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet, as used by ``_separable_conv_wavedecn``. @@ -188,7 +188,7 @@ def fswavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int] = (-2, -1), -) -> WaveletCoeffDetailDict: +) -> WaveletCoeffNd: """Compute a fully separable 2D-padded analysis wavelet transform. Args: @@ -208,7 +208,7 @@ def fswavedec2( Returns: A tuple with the ll coefficients and for each scale a dictionary containing the detail coefficients, - see :data:`ptwt.constants.WaveletCoeffDetailDict`. + see :data:`ptwt.constants.WaveletCoeffNd`. The dictionaries use the filter order strings:: ("ad", "da", "dd") @@ -257,7 +257,7 @@ def fswavedec3( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), -) -> WaveletCoeffDetailDict: +) -> WaveletCoeffNd: """Compute a fully separable 3D-padded analysis wavelet transform. Args: @@ -276,7 +276,7 @@ def fswavedec3( Returns: A tuple with the lll coefficients and for each scale a dictionary containing the detail coefficients, - see :data:`ptwt.constants.WaveletCoeffDetailDict`. + see :data:`ptwt.constants.WaveletCoeffNd`. The dictionaries use the filter order strings:: ("aad", "ada", "add", "daa", "dad", "dda", "ddd") @@ -324,7 +324,7 @@ def fswavedec3( def fswaverec2( - coeffs: WaveletCoeffDetailDict, + coeffs: WaveletCoeffNd, wavelet: Union[Wavelet, str], axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: @@ -334,9 +334,9 @@ def fswaverec2( the hood. Args: - coeffs (WaveletCoeffDetailDict): + coeffs (WaveletCoeffNd): The wavelet coefficients as computed by `fswavedec2`, - see :data:`ptwt.constants.WaveletCoeffDetailDict`. + see :data:`ptwt.constants.WaveletCoeffNd`. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from ``pywt.wavelist(kind='discrete')`` @@ -392,16 +392,16 @@ def fswaverec2( def fswaverec3( - coeffs: WaveletCoeffDetailDict, + coeffs: WaveletCoeffNd, wavelet: Union[Wavelet, str], axes: tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Compute a fully separable 3D-padded synthesis wavelet transform. Args: - coeffs (WaveletCoeffDetailDict): + coeffs (WaveletCoeffNd): The wavelet coefficients as computed by `fswavedec3`, - see :data:`ptwt.constants.WaveletCoeffDetailDict`. + see :data:`ptwt.constants.WaveletCoeffNd`. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. 
Refer to the output from ``pywt.wavelist(kind='discrete')`` From 3939070116a9b6364e9fb03b39825e429dc4bf68 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 13:17:16 +0200 Subject: [PATCH 32/40] Introduce WaveletCoeff2dSeparable alias --- docs/conf.py | 1 + src/ptwt/constants.py | 13 +++++++++++++ src/ptwt/separable_conv_transform.py | 12 ++++++------ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d31a6b76..46c25809 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -83,6 +83,7 @@ autodoc_type_aliases = { "WaveletCoeff2d": "ptwt.constants.WaveletCoeff2d", + "WaveletCoeff2dSeparable": "ptwt.constants.WaveletCoeff2dSeparable", "WaveletCoeffNd": "ptwt.constants.WaveletCoeffNd", "BaseMatrixWaveDec": "ptwt.matmul_transform.BaseMatrixWaveDec", } diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py index 5b274e3a..090c0b9a 100644 --- a/src/ptwt/constants.py +++ b/src/ptwt/constants.py @@ -12,6 +12,7 @@ "OrthogonalizeMethod", "WaveletDetailTuple2d", "WaveletCoeff2d", + "WaveletCoeff2dSeparable", "WaveletCoeffNd", "WaveletDetailDict", ] @@ -112,3 +113,15 @@ class WaveletDetailTuple2d(NamedTuple): Alias of ``tuple[torch.Tensor, *tuple[WaveletDetailDict, ...]]`` """ + +WaveletCoeff2dSeparable: TypeAlias = WaveletCoeffNd +"""Type alias for separable 2d wavelet transform results. + +This is an alias of :data:`ptwt.constants.WaveletCoeffNd`. +It is used to emphasize the use of :data:`ptwt.constants.WaveletDetailDict` +for detail coefficients in a 2d setting -- in contrast to +:data:`ptwt.constants.WaveletCoeff2d`. + +Alias of :data:`ptwt.constants.WaveletCoeffNd`, i.e. of +``tuple[torch.Tensor, *tuple[WaveletDetailDict, ...]]``. +""" diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 6bccaac5..c89b9931 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -27,7 +27,7 @@ _undo_swap_axes, _unfold_axes, ) -from .constants import BoundaryMode, WaveletCoeffNd +from .constants import BoundaryMode, WaveletCoeff2dSeparable, WaveletCoeffNd from .conv_transform import wavedec, waverec from .conv_transform_2 import _preprocess_tensor_dec2d @@ -188,7 +188,7 @@ def fswavedec2( mode: BoundaryMode = "reflect", level: Optional[int] = None, axes: tuple[int, int] = (-2, -1), -) -> WaveletCoeffNd: +) -> WaveletCoeff2dSeparable: """Compute a fully separable 2D-padded analysis wavelet transform. Args: @@ -208,7 +208,7 @@ def fswavedec2( Returns: A tuple with the ll coefficients and for each scale a dictionary containing the detail coefficients, - see :data:`ptwt.constants.WaveletCoeffNd`. + see :data:`ptwt.constants.WaveletCoeff2dSeparable`. The dictionaries use the filter order strings:: ("ad", "da", "dd") @@ -324,7 +324,7 @@ def fswavedec3( def fswaverec2( - coeffs: WaveletCoeffNd, + coeffs: WaveletCoeff2dSeparable, wavelet: Union[Wavelet, str], axes: tuple[int, int] = (-2, -1), ) -> torch.Tensor: @@ -334,9 +334,9 @@ def fswaverec2( the hood. Args: - coeffs (WaveletCoeffNd): + coeffs (WaveletCoeff2dSeparable): The wavelet coefficients as computed by `fswavedec2`, - see :data:`ptwt.constants.WaveletCoeffNd`. + see :data:`ptwt.constants.WaveletCoeff2dSeparable`. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. 
Refer to the output from ``pywt.wavelist(kind='discrete')`` From b32472820b27640c89dd060e1f02be283a717583 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 13:29:22 +0200 Subject: [PATCH 33/40] Make cast in _map_result more narrow and fix tuple creation --- src/ptwt/_util.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 812546c4..446eb63e 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -10,7 +10,13 @@ import pywt import torch -from .constants import OrthogonalizeMethod, WaveletCoeff2d, WaveletCoeffNd +from .constants import ( + OrthogonalizeMethod, + WaveletCoeff2d, + WaveletCoeffNd, + WaveletDetailDict, + WaveletDetailTuple2d, +) class Wavelet(Protocol): @@ -187,14 +193,14 @@ def _map_result( approx = function(data[0]) result_lst: list[ Union[ - tuple[torch.Tensor, torch.Tensor, torch.Tensor], - dict[str, torch.Tensor], + WaveletDetailDict, + WaveletDetailTuple2d, ] ] = [] for element in data[1:]: if isinstance(element, tuple): result_lst.append( - ( + WaveletDetailTuple2d( function(element[0]), function(element[1]), function(element[2]), @@ -206,6 +212,8 @@ def _map_result( else: raise ValueError(f"Unexpected input type {type(element)}") - return_val = approx, *result_lst - return_val = cast(Union[WaveletCoeff2d, WaveletCoeffNd], return_val) - return return_val + # cast since we assume that the full list is of the same type + cast_result_lst = cast( + Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst + ) + return approx, *cast_result_lst From a55f223ac4fd3325794cc933ace4718e8d4a3194 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 13:43:10 +0200 Subject: [PATCH 34/40] Update typing in tests --- tests/test_convolution_fwt.py | 25 +++++++++++++------------ tests/test_convolution_fwt_3.py | 10 +++++----- tests/test_jit.py | 20 ++++++++++---------- tests/test_matrix_fwt.py | 6 +++--- tests/test_matrix_fwt_2.py | 14 +++++++------- tests/test_matrix_fwt_3.py | 10 +++++----- tests/test_packets.py | 12 ++++++------ tests/test_separable_conv_fwt.py | 7 ++++--- tests/test_sparse_math.py | 16 +++++++--------- tests/test_swt.py | 4 ++-- tests/test_util.py | 10 ++++------ tests/test_wavelet.py | 4 +--- 12 files changed, 67 insertions(+), 71 deletions(-) diff --git a/tests/test_convolution_fwt.py b/tests/test_convolution_fwt.py index beadddf8..b7572595 100644 --- a/tests/test_convolution_fwt.py +++ b/tests/test_convolution_fwt.py @@ -1,6 +1,7 @@ """Test the conv-fwt code.""" -from typing import Iterable, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Optional, Union # Written by moritz ( @ wolter.tech ) in 2021 import numpy as np @@ -67,7 +68,7 @@ def test_conv_fwt1d( @pytest.mark.parametrize("size", [[5, 10, 64], [1, 1, 32]]) @pytest.mark.parametrize("wavelet", ["haar", "db2"]) -def test_conv_fwt1d_channel(size: List[int], wavelet: str) -> None: +def test_conv_fwt1d_channel(size: list[int], wavelet: str) -> None: """Test channel dimension support.""" data = torch.randn(*size).type(torch.float64) ptwt_coeff = wavedec(data, wavelet) @@ -84,7 +85,7 @@ def test_conv_fwt1d_channel(size: List[int], wavelet: str) -> None: @pytest.mark.parametrize("size", [[32], [64]]) @pytest.mark.parametrize("wavelet", ["haar", "db2"]) -def test_conv_fwt1d_nobatch(size: List[int], wavelet: str) -> None: +def test_conv_fwt1d_nobatch(size: list[int], wavelet: str) -> None: """1d conv for inputs without batch dim.""" data = 
torch.randn(*size).type(torch.float64) ptwt_coeff = wavedec(data, wavelet) @@ -104,7 +105,7 @@ def test_ripples_haar_lvl3() -> None: class _MyHaarFilterBank: @property - def filter_bank(self) -> Tuple[List[float], ...]: + def filter_bank(self) -> tuple[list[float], ...]: """Unscaled Haar wavelet filters.""" return ( [1 / 2, 1 / 2.0], @@ -241,7 +242,7 @@ def test_outer() -> None: "mode", ["reflect", "zero", "constant", "periodic", "symmetric"] ) def test_2d_wavedec_rec( - wavelet_str: str, level: Optional[int], size: Tuple[int, int], mode: BoundaryMode + wavelet_str: str, level: Optional[int], size: tuple[int, int], mode: BoundaryMode ) -> None: """Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients. @@ -277,7 +278,7 @@ def test_2d_wavedec_rec( @pytest.mark.parametrize("level", [1, None]) @pytest.mark.parametrize("wavelet", ["haar", "sym3"]) def test_input_4d( - size: Tuple[int, int, int, int], level: Optional[str], wavelet: str + size: tuple[int, int, int, int], level: Optional[str], wavelet: str ) -> None: """Test the error for 4d inputs to wavedec2.""" data = torch.randn(*size).type(torch.float64) @@ -319,9 +320,9 @@ def test_input_1d_dimension_error() -> None: def _compare_coeffs( - ptwt_res: Iterable[Union[torch.Tensor, Tuple[torch.Tensor, ...]]], - pywt_res: Iterable[Union[torch.Tensor, Tuple[torch.Tensor, ...]]], -) -> List[bool]: + ptwt_res: Sequence[Union[torch.Tensor, tuple[torch.Tensor, ...]]], + pywt_res: Sequence[Union[torch.Tensor, tuple[torch.Tensor, ...]]], +) -> list[bool]: """Compare coefficient lists. Args: @@ -334,7 +335,7 @@ def _compare_coeffs( Raises: TypeError: In case of a problem with the list structures. """ - test_list: List[bool] = [] + test_list: list[bool] = [] for ptwtcs, pywtcs in zip(ptwt_res, pywt_res): if isinstance(ptwtcs, tuple) and isinstance(pywtcs, tuple): test_list.extend( @@ -352,7 +353,7 @@ def _compare_coeffs( @pytest.mark.parametrize( "size", [(50, 20, 128, 128), (8, 49, 21, 128, 128), (6, 4, 4, 5, 64, 64)] ) -def test_2d_multidim_input(size: Tuple[int, ...]) -> None: +def test_2d_multidim_input(size: tuple[int, ...]) -> None: """Test the error for multi-dimensional inputs to wavedec2.""" data = torch.randn(*size, dtype=torch.float64) wavelet = "db2" @@ -374,7 +375,7 @@ def test_2d_multidim_input(size: Tuple[int, ...]) -> None: @pytest.mark.slow @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (-3, -2), (0, 1), (1, 0)]) -def test_2d_axis_argument(axes: Tuple[int, int]) -> None: +def test_2d_axis_argument(axes: tuple[int, int]) -> None: """Ensure the axes argument works as expected.""" data = torch.randn([32, 32, 32, 32], dtype=torch.float64) diff --git a/tests/test_convolution_fwt_3.py b/tests/test_convolution_fwt_3.py index 36de97e3..6338e759 100644 --- a/tests/test_convolution_fwt_3.py +++ b/tests/test_convolution_fwt_3.py @@ -1,7 +1,7 @@ """Test our 3d for loop-convolution based fwt code.""" import typing -from typing import Any, Dict, List, Union +from typing import Any, Union import numpy as np import numpy.typing as npt @@ -14,8 +14,8 @@ def _expand_dims( - batch_list: List[Union[npt.NDArray[Any], Dict[Any, Any]]] -) -> List[Any]: + batch_list: list[Union[npt.NDArray[Any], dict[Any, Any]]] +) -> list[Any]: for pos, bel in enumerate(batch_list): if isinstance(bel, np.ndarray): batch_list[pos] = np.expand_dims(bel, 0) @@ -66,7 +66,7 @@ def _cat_batch_list(batch_lists: Any) -> Any: @pytest.mark.parametrize("level", [1, 2, None]) @pytest.mark.parametrize("mode", typing.get_args(BoundaryMode)) def test_waverec3( - shape: 
List[int], wavelet: str, level: int, mode: BoundaryMode + shape: list[int], wavelet: str, level: int, mode: BoundaryMode ) -> None: """Ensure the 3d analysis transform is invertible.""" data = np.random.randn(*shape) @@ -102,7 +102,7 @@ def test_waverec3( @pytest.mark.parametrize("level", [1, 2, None]) @pytest.mark.parametrize("wavelet", ["haar", "sym3", "db3"]) @pytest.mark.parametrize("mode", ["zero", "symmetric", "reflect"]) -def test_multidim_input(size: List[int], level: int, wavelet: str, mode: str) -> None: +def test_multidim_input(size: list[int], level: int, wavelet: str, mode: str) -> None: """Ensure correct folding of multidimensional inputs.""" data = torch.randn(size, dtype=torch.float64) ptwc = ptwt.wavedec3(data, wavelet, level=level, mode=mode) diff --git a/tests/test_jit.py b/tests/test_jit.py index a7a0d6a8..c06eda52 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -1,6 +1,6 @@ """Ensure pytorch's torch.jit.trace feature works properly.""" -from typing import Dict, List, NamedTuple, Optional, Tuple, Union +from typing import NamedTuple, Optional, Union import numpy as np import pytest @@ -33,7 +33,7 @@ def _set_up_wavelet_tuple(wavelet: WaveletTuple, dtype: torch.dtype) -> WaveletT def _to_jit_wavedec_fun( data: torch.Tensor, wavelet: Union[ptwt.Wavelet, str], level: Optional[int] -) -> List[torch.Tensor]: +) -> list[torch.Tensor]: return ptwt.wavedec(data, wavelet, mode="reflect", level=level) @@ -69,10 +69,10 @@ def test_conv_fwt_jit( def _to_jit_wavedec_2( data: torch.Tensor, wavelet: Union[str, ptwt.Wavelet] -) -> List[torch.Tensor]: +) -> list[torch.Tensor]: """Ensure uniform datatypes in lists for the tracer. - Going from List[Union[torch.Tensor, Tuple[torch.Tensor]]] to List[torch.Tensor] + Going from list[Union[torch.Tensor, tuple[torch.Tensor]]] to list[torch.Tensor] means we have to stack the lists in the output. """ assert data.shape == (10, 20, 20), "Changing the chape requires re-tracing." @@ -87,10 +87,10 @@ def _to_jit_wavedec_2( def _to_jit_waverec_2( - data: List[torch.Tensor], wavelet: Union[str, ptwt.Wavelet] + data: list[torch.Tensor], wavelet: Union[str, ptwt.Wavelet] ) -> torch.Tensor: """Undo the stacking from the jit wavedec2 wrapper.""" - d_unstack: List[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = [data[0]] + d_unstack: list[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = [data[0]] for c in data[1:]: d_unstack.append(tuple(sc.squeeze(0) for sc in torch.split(c, 1, dim=0))) rec = ptwt.waverec2(d_unstack, wavelet) @@ -121,10 +121,10 @@ def test_conv_fwt_jit_2d() -> None: assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7) -def _to_jit_wavedec_3(data: torch.Tensor, wavelet: str) -> List[torch.Tensor]: +def _to_jit_wavedec_3(data: torch.Tensor, wavelet: str) -> list[torch.Tensor]: """Ensure uniform datatypes in lists for the tracer. - Going from List[Union[torch.Tensor, Dict[str, torch.Tensor]]] to List[torch.Tensor] + Going from list[Union[torch.Tensor, dict[str, torch.Tensor]]] to list[torch.Tensor] means we have to stack the lists in the output. """ assert data.shape == (10, 20, 20, 20), "Changing the shape requires re-tracing." 
@@ -139,9 +139,9 @@ def _to_jit_wavedec_3(data: torch.Tensor, wavelet: str) -> List[torch.Tensor]: return coeff2 -def _to_jit_waverec_3(data: List[torch.Tensor], wavelet: pywt.Wavelet) -> torch.Tensor: +def _to_jit_waverec_3(data: list[torch.Tensor], wavelet: pywt.Wavelet) -> torch.Tensor: """Undo the stacking from the jit wavedec3 wrapper.""" - d_unstack: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [data[0]] + d_unstack: list[Union[torch.Tensor, dict[str, torch.Tensor]]] = [data[0]] keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd") for c in data[1:]: d_unstack.append( diff --git a/tests/test_matrix_fwt.py b/tests/test_matrix_fwt.py index 0dfe5d33..b83b03df 100644 --- a/tests/test_matrix_fwt.py +++ b/tests/test_matrix_fwt.py @@ -2,7 +2,7 @@ # Written by moritz ( @ wolter.tech ) in 2021 -from typing import Any, List +from typing import Any import numpy as np import numpy.typing as npt @@ -71,7 +71,7 @@ def test_fwt_ifwt_mackey_haar_cuda() -> None: @pytest.mark.parametrize("level", [1, 2, 3, 4, None]) @pytest.mark.parametrize("wavelet", ["db2", "db3", "db4", "sym5"]) @pytest.mark.parametrize("size", [[2, 256], [2, 3, 256], [1, 1, 128]]) -def test_1d_matrix_fwt_ifwt(level: int, wavelet: str, size: List[int]) -> None: +def test_1d_matrix_fwt_ifwt(level: int, wavelet: str, size: list[int]) -> None: """Test multiple wavelets and levels for a long signal.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") wavelet = pywt.Wavelet(wavelet) @@ -196,7 +196,7 @@ def test_4d_invalid_axis_error() -> None: @pytest.mark.parametrize("size", [[2, 3, 32], [5, 32], [32], [1, 1, 64]]) -def test_matrix1d_batch_channel(size: List[int]) -> None: +def test_matrix1d_batch_channel(size: list[int]) -> None: """Test if batch and channel support works as expected.""" data = torch.randn(*size).type(torch.float64) matrix_wavedec_1d = MatrixWavedec("haar", 3) diff --git a/tests/test_matrix_fwt_2.py b/tests/test_matrix_fwt_2.py index 6305bd85..92a07197 100644 --- a/tests/test_matrix_fwt_2.py +++ b/tests/test_matrix_fwt_2.py @@ -1,6 +1,6 @@ """Test code for the 2d boundary wavelets.""" -from typing import List, Tuple, Type +from typing import Type import numpy as np import pytest @@ -23,7 +23,7 @@ @pytest.mark.parametrize("size", [(16, 16), (16, 8), (8, 16)]) @pytest.mark.parametrize("wavelet_str", ["db1", "db2", "db3", "db4", "db5"]) -def test_analysis_synthesis_matrices2(size: Tuple[int, int], wavelet_str: str) -> None: +def test_analysis_synthesis_matrices2(size: tuple[int, int], wavelet_str: str) -> None: """Test the 2d analysis and synthesis matrices for various wavelets.""" wavelet = pywt.Wavelet(wavelet_str) a = construct_boundary_a2( @@ -49,7 +49,7 @@ def test_analysis_synthesis_matrices2(size: Tuple[int, int], wavelet_str: str) - @pytest.mark.slow @pytest.mark.parametrize("size", [(8, 16), (16, 8), (15, 16), (16, 15), (16, 16)]) @pytest.mark.parametrize("level", [1, 2, 3]) -def test_matrix_analysis_fwt_2d_haar(size: Tuple[int, int], level: int) -> None: +def test_matrix_analysis_fwt_2d_haar(size: tuple[int, int], level: int) -> None: """Test the fwt-2d matrix-haar transform, should be equal to the pywt.""" face = np.mean( scipy.datasets.face()[256 : (256 + size[0]), 256 : (256 + size[1])], -1 @@ -87,7 +87,7 @@ def test_matrix_analysis_fwt_2d_haar(size: Tuple[int, int], level: int) -> None: @pytest.mark.parametrize("level", [1, 2, 3, None]) @pytest.mark.parametrize("separable", [False, True]) def test_boundary_matrix_fwt_2d( - wavelet_str: str, size: Tuple[int, int], level: int, 
separable: bool + wavelet_str: str, size: tuple[int, int], level: int, separable: bool ) -> None: """Ensure the boundary matrix fwt is invertable.""" face = np.mean( @@ -119,7 +119,7 @@ def test_boundary_matrix_fwt_2d( @pytest.mark.parametrize("size", [(16, 16), (32, 16), (16, 32)]) @pytest.mark.parametrize("separable", [False, True]) def test_batched_2d_matrix_fwt_ifwt( - wavelet_str: str, level: int, size: Tuple[int, int], separable: bool + wavelet_str: str, level: int, size: tuple[int, int], separable: bool ) -> None: """Ensure the batched matrix fwt works properly.""" face = scipy.datasets.face()[256 : (256 + size[0]), 256 : (256 + size[1])].astype( @@ -192,7 +192,7 @@ def test_separable_haar_2d() -> None: @pytest.mark.parametrize("size", [[3, 2, 32, 32], [4, 32, 32], [1, 1, 32, 32]]) -def test_batch_channel_2d_haar(size: List[int]) -> None: +def test_batch_channel_2d_haar(size: list[int]) -> None: """Test matrix fwt-2d leading channel and batch dimension code.""" signal = torch.randn(*size).type(torch.float64) ptwt_coeff = MatrixWavedec2("haar", 2, separable=False)(signal) @@ -235,7 +235,7 @@ def test_empty_inverse_operators(operator: Type[BaseMatrixWaveDec]) -> None: @pytest.mark.slow @pytest.mark.parametrize("axes", ((-2, -1), (-1, -2), (-3, -2), (0, 1), (1, 0))) -def test_axes_2d(axes: Tuple[int, int]) -> None: +def test_axes_2d(axes: tuple[int, int]) -> None: """Ensure the axes argument is supported correctly.""" # TODO: write me. data = torch.randn(24, 24, 24, 24, 24).type(torch.float64) diff --git a/tests/test_matrix_fwt_3.py b/tests/test_matrix_fwt_3.py index a14ebadc..2856e591 100644 --- a/tests/test_matrix_fwt_3.py +++ b/tests/test_matrix_fwt_3.py @@ -1,6 +1,6 @@ """Test the 3d matrix-fwt code.""" -from typing import List, Optional, Tuple +from typing import Optional import numpy as np import pytest @@ -16,7 +16,7 @@ @pytest.mark.parametrize( "shape", [(32, 32, 32), (64, 32, 32), (32, 64, 32), (32, 32, 64)] ) -def test_single_dim_mm(axis: int, shape: Tuple[int, int, int]) -> None: +def test_single_dim_mm(axis: int, shape: tuple[int, int, int]) -> None: """Test the transposed matrix multiplication approach.""" test_tensor = torch.rand(4, shape[0], shape[1], shape[2]).type(torch.float64) pywt_dec_lo, pywt_dec_hi = pywt.wavedec( @@ -32,7 +32,7 @@ def test_single_dim_mm(axis: int, shape: Tuple[int, int, int]) -> None: @pytest.mark.parametrize( "shape", [(32, 32, 32), (64, 32, 32), (32, 64, 32), (32, 32, 64)] ) -def test_boundary_wavedec3_level1_haar(shape: Tuple[int, int, int]) -> None: +def test_boundary_wavedec3_level1_haar(shape: tuple[int, int, int]) -> None: """Test a separable boundary 3d-transform.""" batch_size = 1 test_data = torch.rand(batch_size, shape[0], shape[1], shape[2]).type(torch.float64) @@ -76,7 +76,7 @@ def test_boundary_wavedec3_level1_haar(shape: Tuple[int, int, int]) -> None: "shape", [(31, 32, 33), (63, 35, 32), (32, 62, 31), (32, 32, 64)] ) def test_boundary_wavedec3_inverse( - level: Optional[int], shape: Tuple[int, int, int] + level: Optional[int], shape: tuple[int, int, int] ) -> None: """Test the 3d matrix wavedec and the padding for odd axes.""" batch_size = 1 @@ -91,7 +91,7 @@ def test_boundary_wavedec3_inverse( @pytest.mark.slow @pytest.mark.parametrize("axes", [[-3, -2, -1], [0, 2, 1]]) @pytest.mark.parametrize("level", [1, 2, None]) -def test_axes_arg_matrix_3d(axes: List[int], level: int) -> None: +def test_axes_arg_matrix_3d(axes: list[int], level: int) -> None: """Test axes 3d matmul argument support.""" wavelet = "haar" data = 
torch.randn([16, 16, 16, 16, 16], dtype=torch.float64) diff --git a/tests/test_packets.py b/tests/test_packets.py index f635999e..3e748e43 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -3,7 +3,7 @@ # Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de) from itertools import product -from typing import List, Optional, Tuple +from typing import Optional import numpy as np import pytest @@ -255,7 +255,7 @@ def test_packet_harbo_lvl3() -> None: class _MyHaarFilterBank(object): @property - def filter_bank(self) -> Tuple[List[float], ...]: + def filter_bank(self) -> tuple[list[float], ...]: """Unscaled Haar wavelet filters.""" return ( [1 / 2, 1 / 2.0], @@ -315,7 +315,7 @@ def test_access_errors_2d() -> None: @pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"]) @pytest.mark.parametrize("axis", (1, -1)) def test_inverse_packet_1d( - level: int, base_key: str, shape: List[int], wavelet: str, axis: int + level: int, base_key: str, shape: list[int], wavelet: str, axis: int ) -> None: """Test the 1d reconstruction code.""" signal = np.random.randn(*shape) @@ -340,9 +340,9 @@ def test_inverse_packet_1d( def test_inverse_packet_2d( level: int, base_key: str, - size: Tuple[int, ...], + size: tuple[int, ...], wavelet: str, - axes: Tuple[int, int], + axes: tuple[int, int], ) -> None: """Test the 2d reconstruction code.""" signal = np.random.randn(*size) @@ -390,7 +390,7 @@ def test_inverse_boundary_packet_2d() -> None: @pytest.mark.slow @pytest.mark.parametrize("axes", ((-2, -1), (1, 2), (2, 1))) -def test_separable_conv_packets_2d(axes: Tuple[int, int]) -> None: +def test_separable_conv_packets_2d(axes: tuple[int, int]) -> None: """Ensure the 2d separable conv code is ok.""" wavelet = "db2" signal = np.random.randn(1, 32, 32, 32) diff --git a/tests/test_separable_conv_fwt.py b/tests/test_separable_conv_fwt.py index 6fc8677f..2358898b 100644 --- a/tests/test_separable_conv_fwt.py +++ b/tests/test_separable_conv_fwt.py @@ -1,6 +1,7 @@ """Separable transform test code.""" -from typing import Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Optional import numpy as np import pytest @@ -90,7 +91,7 @@ def test_example_fs3d(shape: Sequence[int], wavelet: str) -> None: ) @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (2, 3), (3, 2)]) def test_conv_mm_2d( - level: Optional[int], shape: Sequence[int], axes: Tuple[int, int] + level: Optional[int], shape: Sequence[int], axes: tuple[int, int] ) -> None: """Compare mm and conv fully separable results.""" data = torch.randn(*shape).type(torch.float64) @@ -119,7 +120,7 @@ def test_conv_mm_2d( @pytest.mark.parametrize("axes", [(-3, -2, -1), (-1, -2, -3), (2, 3, 1)]) @pytest.mark.parametrize("shape", [(5, 64, 128, 256)]) def test_conv_mm_3d( - level: Optional[int], axes: Tuple[int, int, int], shape: Tuple[int, ...] + level: Optional[int], axes: tuple[int, int, int], shape: tuple[int, ...] 
) -> None: """Compare mm and conv 3d fully separable results.""" data = torch.randn(*shape).type(torch.float64) diff --git a/tests/test_sparse_math.py b/tests/test_sparse_math.py index 2266eab6..9fe30ea6 100644 --- a/tests/test_sparse_math.py +++ b/tests/test_sparse_math.py @@ -1,7 +1,5 @@ """Test the sparse math code from ptwt.sparse_math.""" -from typing import Tuple - import numpy as np import pytest import scipy.signal @@ -131,8 +129,8 @@ def test_strided_conv_matrix( @pytest.mark.parametrize("mode", ["same", "full", "valid"]) @pytest.mark.parametrize("fully_sparse", [True, False]) def test_conv_matrix_2d( - filter_shape: Tuple[int, int], - size: Tuple[int, int], + filter_shape: tuple[int, int], + size: tuple[int, int], mode: PaddingMode, fully_sparse: bool, ) -> None: @@ -175,7 +173,7 @@ def test_conv_matrix_2d( ) @pytest.mark.parametrize("mode", ["full", "valid"]) def test_strided_conv_matrix_2d( - filter_shape: Tuple[int, int], size: Tuple[int, int], mode: PaddingMode + filter_shape: tuple[int, int], size: tuple[int, int], mode: PaddingMode ) -> None: """Test strided convolution matrices with full and valid padding.""" test_filter = torch.rand(filter_shape) @@ -219,7 +217,7 @@ def test_strided_conv_matrix_2d( "size", [(7, 8), (8, 7), (7, 7), (8, 8), (16, 16), (8, 16), (16, 8)] ) def test_strided_conv_matrix_2d_same( - filter_shape: Tuple[int, int], size: Tuple[int, int] + filter_shape: tuple[int, int], size: tuple[int, int] ) -> None: """Test strided conv matrix with same padding.""" stride = 2 @@ -250,8 +248,8 @@ def test_strided_conv_matrix_2d_same( def _get_2d_same_padding( - filter_shape: Tuple[int, int], input_size: Tuple[int, int] -) -> Tuple[int, int, int, int]: + filter_shape: tuple[int, int], input_size: tuple[int, int] +) -> tuple[int, int, int, int]: height_offset = input_size[0] % 2 width_offset = input_size[1] % 2 padding = ( @@ -265,7 +263,7 @@ def _get_2d_same_padding( @pytest.mark.slow @pytest.mark.parametrize("size", [(256, 512), (512, 256)]) -def test_strided_conv_matrix_2d_sameshift(size: Tuple[int, int]) -> None: +def test_strided_conv_matrix_2d_sameshift(size: tuple[int, int]) -> None: """Test strided conv matrix with sameshift padding.""" stride = 2 filter_shape = (3, 3) diff --git a/tests/test_swt.py b/tests/test_swt.py index 147c3295..da828eb6 100644 --- a/tests/test_swt.py +++ b/tests/test_swt.py @@ -1,6 +1,6 @@ """Test the stationary wavelet transformation code.""" -from typing import Optional, Tuple +from typing import Optional import numpy as np import pytest @@ -11,7 +11,7 @@ @pytest.mark.parametrize("shape", [(8,), (1, 8), (4, 8), (4, 6, 8), (4, 6, 8, 8)]) -def test_circular_pad(shape: Tuple[int, ...]) -> None: +def test_circular_pad(shape: tuple[int, ...]) -> None: """Test patched circular padding.""" test_data_np = np.random.rand(*shape).astype(np.float32) test_data_pt = torch.from_numpy(test_data_np) diff --git a/tests/test_util.py b/tests/test_util.py index 7109ac57..b4952a71 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,7 +1,5 @@ """Test the util methods.""" -from typing import List, Tuple - import numpy as np import pytest import pywt @@ -18,7 +16,7 @@ class _MyHaarFilterBank: @property - def filter_bank(self) -> Tuple[List[float], List[float], List[float], List[float]]: + def filter_bank(self) -> tuple[list[float], list[float], list[float], list[float]]: """Unscaled Haar wavelet filters.""" return ( [1 / 2, 1 / 2.0], @@ -53,7 +51,7 @@ def test_failed_as_wavelet(wavelet: str) -> None: @pytest.mark.parametrize( "pad_list", [(2, 2), 
(0, 0), (1, 0), (0, 1), (2, 1), (1, 2), (10, 10)] ) -def test_pad_symmetric_1d(size: List[int], pad_list: Tuple[int, int]) -> None: +def test_pad_symmetric_1d(size: list[int], pad_list: tuple[int, int]) -> None: """Test symetric padding in a single dimension.""" test_signal = np.random.randint(0, 9, size=size).astype(np.float32) my_pad = _pad_symmetric_1d(torch.from_numpy(test_signal), pad_list) @@ -63,7 +61,7 @@ def test_pad_symmetric_1d(size: List[int], pad_list: Tuple[int, int]) -> None: @pytest.mark.parametrize("size", [[6, 5], [5, 6], [5, 5], [9, 9], [3, 3]]) @pytest.mark.parametrize("pad_list", [[(1, 4), (4, 1)], [(2, 2), (3, 3)]]) -def test_pad_symmetric(size: List[int], pad_list: List[Tuple[int, int]]) -> None: +def test_pad_symmetric(size: list[int], pad_list: list[tuple[int, int]]) -> None: """Test high-dimensional symetric padding.""" array = np.random.randint(0, 9, size=size) my_pad = _pad_symmetric(torch.from_numpy(array), pad_list) @@ -73,7 +71,7 @@ def test_pad_symmetric(size: List[int], pad_list: List[Tuple[int, int]]) -> None @pytest.mark.parametrize("keep_no", [1, 2, 3]) @pytest.mark.parametrize("size", [[20, 21, 22, 23], [1, 2, 3, 4], [4, 3, 2, 1]]) -def test_fold(keep_no: int, size: List[int]) -> None: +def test_fold(keep_no: int, size: list[int]) -> None: """Ensure channel folding works as expected.""" array = torch.randn(*size).type(torch.float64) folded, ds = _fold_axes(array, keep_no) diff --git a/tests/test_wavelet.py b/tests/test_wavelet.py index 38d0e4ed..5240d24e 100644 --- a/tests/test_wavelet.py +++ b/tests/test_wavelet.py @@ -1,7 +1,5 @@ """Test the adaptive wavelet cost functions.""" -from typing import List - import pytest import pywt import torch @@ -19,7 +17,7 @@ (pywt.wavelist(family="rbio"), False), ], ) -def test_wavelet_lst(lst: List[str], is_orth: bool) -> None: +def test_wavelet_lst(lst: list[str], is_orth: bool) -> None: """Test all wavelets in a list.""" for ws in lst: wavelet = pywt.Wavelet(ws) From 48e87a74f6f7bc5d8aacb4f9ad9863e70003d85c Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 14:46:46 +0200 Subject: [PATCH 35/40] Also import 2d separable type in __init__ --- src/ptwt/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ptwt/__init__.py b/src/ptwt/__init__.py index 5ac62411..a8d8840f 100644 --- a/src/ptwt/__init__.py +++ b/src/ptwt/__init__.py @@ -1,7 +1,7 @@ """Differentiable and gpu enabled fast wavelet transforms in PyTorch.""" from ._util import Wavelet -from .constants import WaveletCoeff2d, WaveletCoeffNd +from .constants import WaveletCoeff2d, WaveletCoeffNd, WaveletCoeff2dSeparable from .continuous_transform import cwt from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 From 4282c9ddcb83e1f7442b742cb54d4522510a1227 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 14:48:44 +0200 Subject: [PATCH 36/40] Improve type alias docstr --- src/ptwt/constants.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py index 090c0b9a..b02e0d11 100644 --- a/src/ptwt/constants.py +++ b/src/ptwt/constants.py @@ -67,11 +67,12 @@ class WaveletDetailTuple2d(NamedTuple): WaveletDetailDict: TypeAlias = dict[str, torch.Tensor] """Type alias for a dict containing detail coefficient for a given level. 
-Thus type alias represents the detail coefficient tensors of a given level for +This type alias represents the detail coefficient tensors of a given level for a wavelet transform in :math:`N` dimensions as the values of a dictionary. Its keys are a string of length :math:`N` describing the detail coefficient -by the applied filter for each axis where 'a' denotes the low pass -or approximation filter and 'd' the high-pass or detail filter. +by the applied filter for each axis. The string consists only of chars 'a' and 'd' +where 'a' denotes the low pass or approximation filter and 'd' the high-pass +or detail filter. For a 3d transform, the dictionary thus uses the keys:: ("aad", "ada", "add", "daa", "dad", "dda", "ddd") @@ -87,9 +88,9 @@ class WaveletDetailTuple2d(NamedTuple): """Type alias for 2d wavelet transform results. This type alias represents the result of a 2d wavelet transform -with :math:`L` levels as a tuple ``(A, T1, T2, ...)`` of length :math:`L + 1` -where ``A`` denotes a tensor of approximation coefficients and -``Tl`` is a tuple of detail coefficients for level ``l``, +with :math:`n` levels as a tuple ``(A, Tn, ..., T1)`` of length :math:`n + 1`. +``A`` denotes a tensor of approximation coefficients for the `n`-th level +of decomposition. ``Tl`` is a tuple of detail coefficients for level ``l``, see :data:`ptwt.constants.WaveletDetailTuple2d`. Note that this type always contains an approximation coefficient tensor but does not @@ -103,9 +104,9 @@ class WaveletDetailTuple2d(NamedTuple): """Type alias for wavelet transform results in any dimension. This type alias represents the result of a Nd wavelet transform -with :math:`L` levels as a tuple ``(A, D1, D2, ...)`` of length :math:`L + 1` -where ``A`` denotes a tensor of approximation coefficients and -``Dl`` is a dictionary of detail coefficients for level ``l``, +with :math:`n` levels as a tuple ``(A, Dn, ..., D1)`` of length :math:`n + 1`. +``A`` denotes a tensor of approximation coefficients for the `n`-th level +of decomposition. ``Dl`` is a dictionary of detail coefficients for level ``l``, see :data:`ptwt.constants.WaveletDetailDict`. 
Note that this type always contains an approximation coefficient tensor but does not From 942ad606a2bf6bdf6df800aced70035fc70d9770 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 15:45:22 +0200 Subject: [PATCH 37/40] Improve typing in JIT code --- examples/speed_tests/timeitconv_1d.py | 24 ++---------- examples/speed_tests/timeitconv_2d.py | 24 ++---------- .../speed_tests/timeitconv_2d_separable.py | 22 +---------- examples/speed_tests/timeitconv_3d.py | 20 +--------- src/ptwt/__init__.py | 4 +- src/ptwt/_util.py | 38 ++++++++++++++++++- tests/test_jit.py | 26 ++----------- 7 files changed, 52 insertions(+), 106 deletions(-) diff --git a/examples/speed_tests/timeitconv_1d.py b/examples/speed_tests/timeitconv_1d.py index 135f77c9..18b317df 100644 --- a/examples/speed_tests/timeitconv_1d.py +++ b/examples/speed_tests/timeitconv_1d.py @@ -9,26 +9,8 @@ import ptwt -class WaveletTuple(NamedTuple): - """Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi")).""" - - dec_lo: torch.Tensor - dec_hi: torch.Tensor - rec_lo: torch.Tensor - rec_hi: torch.Tensor - - -def _set_up_wavelet_tuple(wavelet, dtype): - return WaveletTuple( - torch.tensor(wavelet.dec_lo).type(dtype), - torch.tensor(wavelet.dec_hi).type(dtype), - torch.tensor(wavelet.rec_lo).type(dtype), - torch.tensor(wavelet.rec_hi).type(dtype), - ) - - def _jit_wavedec_fun(data, wavelet): - return ptwt.wavedec(data, wavelet, "periodic", level=10) + return ptwt.wavedec(data, wavelet, mode="periodic", level=10) if __name__ == "__main__": @@ -56,7 +38,7 @@ def _jit_wavedec_fun(data, wavelet): end = time.perf_counter() ptwt_time_cpu.append(end - start) - wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32) + wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32) jit_wavedec = torch.jit.trace( _jit_wavedec_fun, (data, wavelet), @@ -81,7 +63,7 @@ def _jit_wavedec_fun(data, wavelet): end = time.perf_counter() ptwt_time_gpu.append(end - start) - wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32) + wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32) jit_wavedec = torch.jit.trace( _jit_wavedec_fun, (data.cuda(), wavelet), diff --git a/examples/speed_tests/timeitconv_2d.py b/examples/speed_tests/timeitconv_2d.py index cada2ed0..f0e6c5dd 100644 --- a/examples/speed_tests/timeitconv_2d.py +++ b/examples/speed_tests/timeitconv_2d.py @@ -9,27 +9,9 @@ import ptwt -class WaveletTuple(NamedTuple): - """Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi")).""" - - dec_lo: torch.Tensor - dec_hi: torch.Tensor - rec_lo: torch.Tensor - rec_hi: torch.Tensor - - -def _set_up_wavelet_tuple(wavelet, dtype): - return WaveletTuple( - torch.tensor(wavelet.dec_lo).type(dtype), - torch.tensor(wavelet.dec_hi).type(dtype), - torch.tensor(wavelet.rec_lo).type(dtype), - torch.tensor(wavelet.rec_hi).type(dtype), - ) - - -def _to_jit_wavedec_2(data, wavelet): +def _to_jit_wavedec_2(data: torch.Tensor, wavelet) -> list[torch.Tensor]: """Ensure uniform datatypes in lists for the tracer. - Going from List[Union[torch.Tensor, List[torch.Tensor]]] to List[torch.Tensor] + Going from list[Union[torch.Tensor, list[torch.Tensor]]] to list[torch.Tensor] means we have to stack the lists in the output. """ assert data.shape == (32, 1e3, 1e3), "Changing the chape requires re-tracing." 
@@ -79,7 +61,7 @@ def _to_jit_wavedec_2(data, wavelet): ptwt_time_gpu.append(end - start) - wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32) + wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32) jit_wavedec = torch.jit.trace( _to_jit_wavedec_2, (data.cuda(), wavelet), diff --git a/examples/speed_tests/timeitconv_2d_separable.py b/examples/speed_tests/timeitconv_2d_separable.py index 770de1d2..98cdc557 100644 --- a/examples/speed_tests/timeitconv_2d_separable.py +++ b/examples/speed_tests/timeitconv_2d_separable.py @@ -10,31 +10,13 @@ import ptwt -class WaveletTuple(NamedTuple): - """Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi")).""" - - dec_lo: torch.Tensor - dec_hi: torch.Tensor - rec_lo: torch.Tensor - rec_hi: torch.Tensor - - -def _set_up_wavelet_tuple(wavelet, dtype): - return WaveletTuple( - torch.tensor(wavelet.dec_lo).type(dtype), - torch.tensor(wavelet.dec_hi).type(dtype), - torch.tensor(wavelet.rec_lo).type(dtype), - torch.tensor(wavelet.rec_hi).type(dtype), - ) - - def _to_jit_wavedec_2(data, wavelet): """Ensure uniform datatypes in lists for the tracer. Going from List[Union[torch.Tensor, List[torch.Tensor]]] to List[torch.Tensor] means we have to stack the lists in the output. """ assert data.shape == (32, 1e3, 1e3), "Changing the chape requires re-tracing." - coeff = ptwt.fswavedec2(data, wavelet, "reflect", level=5) + coeff = ptwt.fswavedec2(data, wavelet, mode="reflect", level=5) coeff2 = [] for c in coeff: if isinstance(c, torch.Tensor): @@ -103,7 +85,7 @@ def _to_jit_wavedec_2(data, wavelet): end = time.perf_counter() ptwt_time_gpu.append(end - start) - wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32) + wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32) jit_wavedec = torch.jit.trace( _to_jit_wavedec_2, (data.cuda(), wavelet), diff --git a/examples/speed_tests/timeitconv_3d.py b/examples/speed_tests/timeitconv_3d.py index af47a0c8..8a933186 100644 --- a/examples/speed_tests/timeitconv_3d.py +++ b/examples/speed_tests/timeitconv_3d.py @@ -9,24 +9,6 @@ import ptwt -class WaveletTuple(NamedTuple): - """Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi")).""" - - dec_lo: torch.Tensor - dec_hi: torch.Tensor - rec_lo: torch.Tensor - rec_hi: torch.Tensor - - -def _set_up_wavelet_tuple(wavelet, dtype): - return WaveletTuple( - torch.tensor(wavelet.dec_lo).type(dtype), - torch.tensor(wavelet.dec_hi).type(dtype), - torch.tensor(wavelet.rec_lo).type(dtype), - torch.tensor(wavelet.rec_hi).type(dtype), - ) - - def _to_jit_wavedec_3(data, wavelet): """Ensure uniform datatypes in lists for the tracer. 
@@ -85,7 +67,7 @@ def _to_jit_wavedec_3(data, wavelet): end = time.perf_counter() ptwt_time_gpu.append(end - start) - wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32) + wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32) jit_wavedec = torch.jit.trace( _to_jit_wavedec_3, (data.cuda(), wavelet), diff --git a/src/ptwt/__init__.py b/src/ptwt/__init__.py index a8d8840f..69752e5c 100644 --- a/src/ptwt/__init__.py +++ b/src/ptwt/__init__.py @@ -1,7 +1,7 @@ """Differentiable and gpu enabled fast wavelet transforms in PyTorch.""" -from ._util import Wavelet -from .constants import WaveletCoeff2d, WaveletCoeffNd, WaveletCoeff2dSeparable +from ._util import Wavelet, WaveletTensorTuple +from .constants import WaveletCoeff2d, WaveletCoeff2dSeparable, WaveletCoeffNd from .continuous_transform import cwt from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 446eb63e..414323f1 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -4,7 +4,7 @@ import typing from collections.abc import Sequence -from typing import Any, Callable, Optional, Protocol, Union, cast, overload +from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload import numpy as np import pywt @@ -38,6 +38,42 @@ def __len__(self) -> int: return len(self.dec_lo) +class WaveletTensorTuple(NamedTuple): + """Named tuple containing the wavelet filter bank to use in JIT code.""" + + dec_lo: torch.Tensor + dec_hi: torch.Tensor + rec_lo: torch.Tensor + rec_hi: torch.Tensor + + @property + def dec_len(self) -> int: + """Length of decomposition filters.""" + return len(self.dec_lo) + + @property + def rec_len(self) -> int: + """Length of reconstruction filters.""" + return len(self.rec_lo) + + @property + def filter_bank( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Filter bank of the wavelet.""" + return self + + @classmethod + def from_wavelet(cls, wavelet: Wavelet, dtype: torch.dtype) -> WaveletTensorTuple: + """Construct Wavelet named tuple from wavelet protocol member.""" + return cls( + torch.tensor(wavelet.dec_lo, dtype=dtype), + torch.tensor(wavelet.dec_hi, dtype=dtype), + torch.tensor(wavelet.rec_lo, dtype=dtype), + torch.tensor(wavelet.rec_hi, dtype=dtype), + ) + + def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet: """Ensure the input argument to be a pywt wavelet compatible object. 
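A minimal usage sketch (not part of this diff): with the new `WaveletTensorTuple.from_wavelet` helper, the per-script `WaveletTuple` boilerplate removed from the example scripts above reduces to roughly the following; the signal shape, wavelet name and decomposition level are arbitrary placeholders:

    import pywt
    import torch

    import ptwt

    # Sketch only: build the tensor filter bank once and reuse it when tracing.
    wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)

    def _dec(data: torch.Tensor, wavelet) -> list[torch.Tensor]:
        return ptwt.wavedec(data, wavelet, mode="periodic", level=3)

    data = torch.randn(32, 1024)
    # strict=False lets the tracer accept the list-valued output.
    traced = torch.jit.trace(_dec, (data, wavelet), strict=False)
    coeffs = traced(data, wavelet)
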
diff --git a/tests/test_jit.py b/tests/test_jit.py index c06eda52..06b4db15 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -1,6 +1,6 @@ """Ensure pytorch's torch.jit.trace feature works properly.""" -from typing import NamedTuple, Optional, Union +from typing import Optional, Union import numpy as np import pytest @@ -13,24 +13,6 @@ from tests._mackey_glass import MackeyGenerator -class WaveletTuple(NamedTuple): - """Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi")).""" - - dec_lo: torch.Tensor - dec_hi: torch.Tensor - rec_lo: torch.Tensor - rec_hi: torch.Tensor - - -def _set_up_wavelet_tuple(wavelet: WaveletTuple, dtype: torch.dtype) -> WaveletTuple: - return WaveletTuple( - torch.tensor(wavelet.dec_lo).type(dtype), - torch.tensor(wavelet.dec_hi).type(dtype), - torch.tensor(wavelet.rec_lo).type(dtype), - torch.tensor(wavelet.rec_hi).type(dtype), - ) - - def _to_jit_wavedec_fun( data: torch.Tensor, wavelet: Union[ptwt.Wavelet, str], level: Optional[int] ) -> list[torch.Tensor]: @@ -53,7 +35,7 @@ def test_conv_fwt_jit( mackey_data_1 = torch.squeeze(generator(), -1).type(dtype) wavelet = pywt.Wavelet(wavelet_string) - wavelet = _set_up_wavelet_tuple(wavelet, dtype) + wavelet = ptwt.WaveletTensorTuple.from_wavelet(wavelet, dtype) with pytest.warns(Warning): jit_wavedec = torch.jit.trace( # type: ignore @@ -105,7 +87,7 @@ def test_conv_fwt_jit_2d() -> None: rec = _to_jit_waverec_2(coeff, wavelet) assert np.allclose(rec.squeeze(1).numpy(), data.numpy()) - wavelet = _set_up_wavelet_tuple(wavelet, dtype=torch.float64) + wavelet = ptwt.WaveletTensorTuple.from_wavelet(wavelet, dtype=torch.float64) with pytest.warns(Warning): jit_wavedec2 = torch.jit.trace( # type: ignore _to_jit_wavedec_2, @@ -159,7 +141,7 @@ def test_conv_fwt_jit_3d() -> None: rec = _to_jit_waverec_3(coeff, wavelet) assert np.allclose(rec.squeeze(1).numpy(), data.numpy()) - wavelet = _set_up_wavelet_tuple(wavelet, dtype=torch.float64) + wavelet = ptwt.WaveletTensorTuple.from_wavelet(wavelet, dtype=torch.float64) with pytest.warns(Warning): jit_wavedec3 = torch.jit.trace( # type: ignore _to_jit_wavedec_3, From b8a510a926ba85291f4621b0ac797d59617b8262 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 15:47:12 +0200 Subject: [PATCH 38/40] Adapt right pad logic to avoid JIT tracer warning --- src/ptwt/conv_transform.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index f4898961..de5876b0 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -100,8 +100,7 @@ def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]: padl = (2 * filt_len - 3) // 2 # pad to even singal length. - if data_len % 2 != 0: - padr += 1 + padr += data_len % 2 return padr, padl From 0ac20c518e2cbf1ed81441d124dfbf2077cdcaa0 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Fri, 14 Jun 2024 16:12:03 +0200 Subject: [PATCH 39/40] Refactor of order methods for WaveletPacket2d Make the functions creating the natural and frequency packet order for WaveletPacket2d static methods of WaveletPacket2d. This changes * `get_natural_order` from instance to static function * `get_frequency_order` from separate func to static function in the scope of WaveletPacket2d Further, to make both methods consistent, `get_frequency_order` now returns concatenated strings instead of a tuple of single char strings. 
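
A minimal sketch of the resulting call pattern (not part of this commit; import path as in the accompanying test changes):

    from ptwt.packets import WaveletPacket2D

    # Sketch of the new static methods, illustration only.
    # Natural (filter-application) order for two decomposition levels:
    # ["aa", "ah", "av", "ad", "ha", ...]
    natural = WaveletPacket2D.get_natural_order(2)

    # Frequency order: rows/columns in graycode order; entries are now joined
    # strings such as "av" rather than tuples of single characters, so they
    # can be used directly as packet keys.
    freq = WaveletPacket2D.get_freq_order(2)
    assert freq[0][:2] == ["aa", "av"]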
--- examples/deepfake_analysis/packet_plot.py | 47 +------- src/ptwt/packets.py | 127 +++++++++++----------- tests/test_packets.py | 4 +- 3 files changed, 71 insertions(+), 107 deletions(-) diff --git a/examples/deepfake_analysis/packet_plot.py b/examples/deepfake_analysis/packet_plot.py index 3248d61e..83123f6e 100644 --- a/examples/deepfake_analysis/packet_plot.py +++ b/examples/deepfake_analysis/packet_plot.py @@ -11,47 +11,6 @@ import ptwt -def get_freq_order(level: int): - """Get the frequency order for a given packet decomposition level. - Adapted from: - https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py - The code elements denote the filter application order. The filters - are named following the pywt convention as: - a - LL, low-low coefficients - h - LH, low-high coefficients - v - HL, high-low coefficients - d - HH, high-high coefficients - """ - wp_natural_path = list(product(["a", "h", "v", "d"], repeat=level)) - - def _get_graycode_order(level, x="a", y="d"): - graycode_order = [x, y] - for _ in range(level - 1): - graycode_order = [x + path for path in graycode_order] + [ - y + path for path in graycode_order[::-1] - ] - return graycode_order - - def _expand_2d_path(path): - expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"} - return ( - "".join([expanded_paths[p][0] for p in path]), - "".join([expanded_paths[p][1] for p in path]), - ) - - nodes: dict = {} - for (row_path, col_path), node in [ - (_expand_2d_path(node), node) for node in wp_natural_path - ]: - nodes.setdefault(row_path, {})[col_path] = node - graycode_order = _get_graycode_order(level, x="l", y="h") - nodes_list: list = [nodes[path] for path in graycode_order if path in nodes] - wp_frequency_path = [] - for row in nodes_list: - wp_frequency_path.append([row[path] for path in graycode_order if path in row]) - return wp_frequency_path, wp_natural_path - - def generate_frequency_packet_image(packet_array: np.ndarray, degree: int): """Create a ready-to-polt image with frequency-order packages. Given a packet array in natural order, creat an image which is @@ -63,7 +22,8 @@ def generate_frequency_packet_image(packet_array: np.ndarray, degree: int): Returns: [np.ndarray]: The image of shape [original_height, original_width] """ - wp_freq_path, wp_natural_path = get_freq_order(degree) + wp_freq_path = ptwt.WaveletPacket2D.get_freq_order(degree) + wp_natural_path = ptwt.WaveletPacket2D.get_natural_order(degree) image = [] # go through the rows. @@ -107,7 +67,8 @@ def load_images(path: str) -> list: if __name__ == "__main__": - frequency_path, natural_path = get_freq_order(level=3) + freq_path = ptwt.WaveletPacket2D.get_freq_order(level=3) + frequency_path = ptwt.WaveletPacket2D.get_natural_order(level=3) print("Loading ffhq images:") ffhq_images = load_images("./ffhq_style_gan/source_data/A_ffhq") print("processing ffhq") diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 33b984de..879f5eff 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -122,7 +122,7 @@ def __init__( def transform( self, data: torch.Tensor, maxlevel: Optional[int] = None - ) -> "WaveletPacket": + ) -> WaveletPacket: """Calculate the 1d wavelet packet transform for the input data. Args: @@ -139,7 +139,7 @@ def transform( self._recursive_dwt(data, level=0, path="") return self - def reconstruct(self) -> "WaveletPacket": + def reconstruct(self) -> WaveletPacket: """Recursively reconstruct the input starting from the leaf nodes. 
Reconstruction replaces the input data originally assigned to this object. @@ -326,7 +326,7 @@ def __init__( def transform( self, data: torch.Tensor, maxlevel: Optional[int] = None - ) -> "WaveletPacket2D": + ) -> WaveletPacket2D: """Calculate the 2d wavelet packet transform for the input data. The transform function allows reusing the same object. @@ -350,7 +350,7 @@ def transform( self._recursive_dwt2d(data, level=0, path="") return self - def reconstruct(self) -> "WaveletPacket2D": + def reconstruct(self) -> WaveletPacket2D: """Recursively reconstruct the input starting from the leaf nodes. Note: @@ -364,7 +364,7 @@ def reconstruct(self) -> "WaveletPacket2D": ) for level in reversed(range(self.maxlevel)): - for node in self.get_natural_order(level): + for node in WaveletPacket2D.get_natural_order(level): data_a = self[node + "a"] data_h = self[node + "h"] data_v = self[node + "v"] @@ -386,17 +386,6 @@ def reconstruct(self) -> "WaveletPacket2D": self[node] = rec return self - def get_natural_order(self, level: int) -> list[str]: - """Get the natural ordering for a given decomposition level. - - Args: - level (int): The decomposition level. - - Returns: - A list with the filter order strings. - """ - return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)] - def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ [torch.Tensor], WaveletCoeff2d, @@ -525,54 +514,68 @@ def __getitem__(self, key: str) -> torch.Tensor: ) return super().__getitem__(key) + @staticmethod + def get_natural_order(level: int) -> list[str]: + """Get the natural ordering for a given decomposition level. -def get_freq_order(level: int) -> list[list[tuple[str, ...]]]: - """Get the frequency order for a given packet decomposition level. + Args: + level (int): The decomposition level. - Use this code to create two-dimensional frequency orderings. + Returns: + A list with the filter order strings. + """ + return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)] - Args: - level (int): The number of decomposition scales. + @staticmethod + def get_freq_order(level: int) -> list[list[str]]: + """Get the frequency order for a given packet decomposition level. - Returns: - A list with the tree nodes in frequency order. - - Note: - Adapted from: - https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py - - The code elements denote the filter application order. The filters - are named following the pywt convention as: - a - LL, low-low coefficients - h - LH, low-high coefficients - v - HL, high-low coefficients - d - HH, high-high coefficients - """ - wp_natural_path = product(["a", "h", "v", "d"], repeat=level) + Use this code to create two-dimensional frequency orderings. 
- def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]: - graycode_order = [x, y] - for _ in range(level - 1): - graycode_order = [x + path for path in graycode_order] + [ - y + path for path in graycode_order[::-1] - ] - return graycode_order - - def _expand_2d_path(path: tuple[str, ...]) -> tuple[str, str]: - expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"} - return ( - "".join([expanded_paths[p][0] for p in path]), - "".join([expanded_paths[p][1] for p in path]), - ) - - nodes_dict: dict[str, dict[str, tuple[str, ...]]] = {} - for (row_path, col_path), node in [ - (_expand_2d_path(node), node) for node in wp_natural_path - ]: - nodes_dict.setdefault(row_path, {})[col_path] = node - graycode_order = _get_graycode_order(level, x="l", y="h") - nodes = [nodes_dict[path] for path in graycode_order if path in nodes_dict] - result = [] - for row in nodes: - result.append([row[path] for path in graycode_order if path in row]) - return result + Args: + level (int): The number of decomposition scales. + + Returns: + A list with the tree nodes in frequency order. + + Note: + Adapted from: + https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py + + The code elements denote the filter application order. The filters + are named following the pywt convention as: + a - LL, low-low coefficients + h - LH, low-high coefficients + v - HL, high-low coefficients + d - HH, high-high coefficients + """ + wp_natural_path = product(["a", "h", "v", "d"], repeat=level) + + def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]: + graycode_order = [x, y] + for _ in range(level - 1): + graycode_order = [x + path for path in graycode_order] + [ + y + path for path in graycode_order[::-1] + ] + return graycode_order + + def _expand_2d_path(path: tuple[str, ...]) -> tuple[str, str]: + expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"} + return ( + "".join([expanded_paths[p][0] for p in path]), + "".join([expanded_paths[p][1] for p in path]), + ) + + nodes_dict: dict[str, dict[str, tuple[str, ...]]] = {} + for (row_path, col_path), node in [ + (_expand_2d_path(node), node) for node in wp_natural_path + ]: + nodes_dict.setdefault(row_path, {})[col_path] = node + graycode_order = _get_graycode_order(level, x="l", y="h") + nodes = [nodes_dict[path] for path in graycode_order if path in nodes_dict] + result = [] + for row in nodes: + result.append( + ["".join(row[path]) for path in graycode_order if path in row] + ) + return result diff --git a/tests/test_packets.py b/tests/test_packets.py index 3e748e43..00cf6c4b 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -12,7 +12,7 @@ from scipy import datasets from ptwt.constants import ExtendedBoundaryMode -from ptwt.packets import WaveletPacket, WaveletPacket2D, get_freq_order +from ptwt.packets import WaveletPacket, WaveletPacket2D def _compare_trees1( @@ -236,7 +236,7 @@ def test_freq_order(level: int, wavelet_str: str, pywt_boundary: str) -> None: ) # Get the full decomposition freq_tree = wp_tree.get_level(level, "freq") - freq_order = get_freq_order(level) + freq_order = WaveletPacket2D.get_freq_order(level) for order_list, tree_list in zip(freq_tree, freq_order): for order_el, tree_el in zip(order_list, tree_list): From c97335046d2104f32c65139341f65b7fc41f3cea Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 18 Jun 2024 14:05:11 +0200 Subject: [PATCH 40/40] Use mypy main branch for typing session for now --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/setup.cfg b/setup.cfg index 1f4da72e..770369d8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,7 +67,7 @@ tests = # pooch is an optional scipy dependency for getting datasets pooch typing = - mypy + mypy @ git+https://github.com/python/mypy # needed otherwise pytest decorators don't get typed properly pytest examples =