From 57c1ab7f147b26532b14b0b65728733ea2afdf63 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 4 Jun 2024 00:16:36 +0200 Subject: [PATCH] Format --- src/ptwt/_util.py | 30 ++++++++++++++-------------- src/ptwt/conv_transform.py | 1 + src/ptwt/conv_transform_2.py | 4 +++- src/ptwt/conv_transform_3.py | 5 +---- src/ptwt/matmul_transform_2.py | 4 +--- src/ptwt/matmul_transform_3.py | 8 ++------ src/ptwt/packets.py | 18 +++++++++++------ src/ptwt/separable_conv_transform.py | 2 +- 8 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index d8d597cf..951d340a 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -1,13 +1,13 @@ """Utility methods to compute wavelet decompositions from a dataset.""" -from collections.abc import Sequence import typing +from collections.abc import Sequence from typing import Any, Callable, Optional, Protocol, Union, cast, overload -from typing_extensions import Unpack import numpy as np import pywt import torch +from typing_extensions import Unpack from .constants import OrthogonalizeMethod @@ -30,10 +30,13 @@ def __len__(self) -> int: """Return the number of filter coefficients.""" return len(self.dec_lo) + WaveletDetailTuple2d = tuple[torch.Tensor, torch.Tensor, torch.Tensor] WaveletDetailDict = dict[str, torch.Tensor] -WaveletCoeffDetailTuple2d = tuple[torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]]] +WaveletCoeffDetailTuple2d = tuple[ + torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]] +] WaveletCoeffDetailDict = tuple[torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]]] @@ -96,7 +99,7 @@ def _pad_symmetric_1d(signal: torch.Tensor, pad_list: tuple[int, int]) -> torch. cat_list.insert(0, signal[:padl].flip(0)) if padr > 0: cat_list.append(signal[-padr::].flip(0)) - return torch.cat(cat_list, axis=0) # type: ignore + return torch.cat(cat_list, dim=0) def _pad_symmetric( @@ -118,7 +121,7 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int """Fold unchanged leading dimensions into a single batch dimension. Args: - data ( torch.Tensor): The input data array. + data (torch.Tensor): The input data array. keep_no (int): The number of dimensions to keep. Returns: @@ -176,21 +179,19 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: def _map_result( data: WaveletCoeffDetailTuple2d, function: Callable[[torch.Tensor], torch.Tensor], -) -> WaveletCoeffDetailTuple2d: - ... +) -> WaveletCoeffDetailTuple2d: ... @overload def _map_result( data: WaveletCoeffDetailDict, function: Callable[[torch.Tensor], torch.Tensor], -) -> WaveletCoeffDetailDict: - ... +) -> WaveletCoeffDetailDict: ... def _map_result( data: Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], - function: Callable[[torch.Tensor], torch.Tensor] + function: Callable[[torch.Tensor], torch.Tensor], ) -> Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict]: return_tuple = isinstance(data, tuple) approx = function(data[0]) @@ -210,14 +211,13 @@ def _map_result( ) ) elif isinstance(element, dict): - new_dict = { - key: function(value) - for key, value in element.items() - } + new_dict = {key: function(value) for key, value in element.items()} result_lst.append(new_dict) else: raise AssertionError(f"Unexpected input type {type(element)}") return_val = approx, *result_lst - return_val = cast(Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], return_val) + return_val = cast( + Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], return_val + ) return return_val diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 66c9798e..77a39776 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -2,6 +2,7 @@ This module treats boundaries with edge-padding. """ + from collections.abc import Sequence from typing import Optional, Union, cast diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 229c910f..0b737129 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -97,7 +97,9 @@ def _fwt_pad2( return data_pad -def _waverec2d_fold_channels_2d_list(coeffs: WaveletCoeffDetailTuple2d) -> tuple[WaveletCoeffDetailTuple2d, list[int]]: +def _waverec2d_fold_channels_2d_list( + coeffs: WaveletCoeffDetailTuple2d, +) -> tuple[WaveletCoeffDetailTuple2d, list[int]]: # fold the input coefficients for processing conv2d_transpose. ds = list(_check_if_tensor(coeffs[0]).shape) return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index a9debf59..6883b00e 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -222,10 +222,7 @@ def _waverec3d_fold_channels_3d_list( fold_coeffs: list[dict[str, torch.Tensor]] = [] ds = list(_check_if_tensor(coeffs[0]).shape) fold_coeffs = [ - { - key: _fold_axes(value, 3)[0] - for key, value in coeff.items() - } + {key: _fold_axes(value, 3)[0] for key, value in coeff.items()} for coeff in coeffs[1:] ] return (fold_approx_coeff, *fold_coeffs), ds diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index f1ff9157..fc596953 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -416,9 +416,7 @@ def _construct_analysis_matrices( current_width = current_width // 2 self.size_list.append((current_height, current_width)) - def __call__( - self, input_signal: torch.Tensor - ) -> WaveletCoeffDetailTuple2d: + def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d: """Compute the fwt for the given input signal. The fwt matrix is set up during the first call diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index a71fa524..18736915 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -156,9 +156,7 @@ def _construct_analysis_matrices( ) self.size_list.append((current_depth, current_height, current_width)) - def __call__( - self, input_signal: torch.Tensor - ) -> WaveletCoeffDetailDict: + def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailDict: """Compute a separable 3d-boundary wavelet transform. Args: @@ -388,9 +386,7 @@ def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Ten return cat_tensor return self._cat_coeff_recursive(done_dict) - def __call__( - self, coefficients: WaveletCoeffDetailDict - ) -> torch.Tensor: + def __call__(self, coefficients: WaveletCoeffDetailDict) -> torch.Tensor: """Reconstruct a batched 3d-signal from its coefficients. Args: diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 3376b88e..868ffe07 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -10,7 +10,12 @@ import pywt import torch -from ._util import Wavelet, WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict, _as_wavelet +from ._util import ( + Wavelet, + WaveletCoeffDetailDict, + WaveletCoeffDetailTuple2d, + _as_wavelet, +) from .constants import ExtendedBoundaryMode, OrthogonalizeMethod from .conv_transform import wavedec, waverec from .conv_transform_2 import wavedec2, waverec2 @@ -384,7 +389,8 @@ def get_natural_order(self, level: int) -> list[str]: return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)] def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ - [torch.Tensor], WaveletCoeffDetailTuple2d, + [torch.Tensor], + WaveletCoeffDetailTuple2d, ]: if self.mode == "boundary": shape = tuple(shape) @@ -413,7 +419,9 @@ def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes ) - def _get_waverec(self, shape: tuple[int, ...]) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: + def _get_waverec( + self, shape: tuple[int, ...] + ) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: if self.mode == "boundary": shape = tuple(shape) if shape not in self.matrix_waverec2_dict.keys(): @@ -450,9 +458,7 @@ def _transform_tuple_to_fsdict_func( self, fsdict_func: Callable[[WaveletCoeffDetailDict], torch.Tensor], ) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]: - def _fsdict_func( - coeffs: WaveletCoeffDetailTuple2d - ) -> torch.Tensor: + def _fsdict_func(coeffs: WaveletCoeffDetailTuple2d) -> torch.Tensor: # assert for type checking assert len(coeffs) == 2 a, (h, v, d) = coeffs diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 63d7e142..bde51f71 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -16,8 +16,8 @@ from ._util import ( Wavelet, - WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict, + WaveletCoeffDetailTuple2d, _as_wavelet, _check_axes_argument, _check_if_tensor,