From 1c9778cdf24079180cbdadfd4c5122c0a0c53cb6 Mon Sep 17 00:00:00 2001 From: loki-veera Date: Thu, 21 Sep 2023 17:23:13 +0200 Subject: [PATCH] add axes support for conv2d along with tests --- src/ptwt/_util.py | 13 ++-- src/ptwt/conv_transform_2.py | 117 ++++++++++++++++++++++++---------- tests/test_convolution_fwt.py | 71 +++++++++++++++++++++ 3 files changed, 163 insertions(+), 38 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 6bfdd5bd..ec1f9737 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -134,11 +134,14 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int The folded result array, and the shape of the original input. """ dshape = list(data.shape) - return torch.reshape(data, [np.prod(dshape[:-keep_no])] + dshape[-keep_no:]), dshape + return ( + torch.reshape(data, [int(np.prod(dshape[:-keep_no]))] + dshape[-keep_no:]), + dshape, + ) def _unfold_axes(data: torch.Tensor, ds: List[int], keep_no: int) -> torch.Tensor: - """Unfold i.e. [batch*channel, height, widht] into [batch, channel, height, width].""" + """Unfold [batch*channel, height, widht] into [batch, channel, height, width].""" return torch.reshape(data, ds[:-keep_no] + list(data.shape[-keep_no:])) @@ -167,11 +170,11 @@ def _get_transpose_order( def _swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor: _check_axes_argument(axes) front, back = _get_transpose_order(axes, list(data.shape)) - return torch.transpose(data, front + back) + return torch.permute(data, front + back) def _undo_swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor: _check_axes_argument(axes) front, back = _get_transpose_order(axes, list(data.shape)) - restore_sorted = torch.argsort(torch.tensor(front + back)) - return torch.transpose(data, restore_sorted) + restore_sorted = torch.argsort(torch.tensor(front + back)).tolist() + return torch.permute(data, restore_sorted) diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 7b57a967..e4c889fb 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -5,7 +5,8 @@ """ -from typing import Any, List, Optional, Tuple, Union +from functools import partial +from typing import Any, Callable, List, Optional, Tuple, Union import pywt import torch @@ -13,11 +14,15 @@ from ._util import ( Wavelet, _as_wavelet, - _fold_channels, + _check_axes_argument, + _fold_axes, _get_len, _is_dtype_supported, _outer, _pad_symmetric, + _swap_axes, + _undo_swap_axes, + _unfold_axes, _unfold_channels, ) from .conv_transform import ( @@ -122,23 +127,42 @@ def _waverec2d_fold_channels_2d_list( ds = list(_check_if_tensor(coeffs[0]).shape) for coeff in coeffs: if isinstance(coeff, torch.Tensor): - fold_coeffs.append(_fold_channels(coeff)) + fold_coeffs.append(_fold_axes(coeff, 2)[0]) else: fold_coeffs.append( ( - _fold_channels(coeff[0]), - _fold_channels(coeff[1]), - _fold_channels(coeff[2]), + _fold_axes(coeff[0], 2)[0], + _fold_axes(coeff[1], 2)[0], + _fold_axes(coeff[2], 2)[0], ) ) return fold_coeffs, ds +def _preprocess_tensor_dec2d( + data: torch.Tensor, +) -> Tuple[torch.Tensor, Union[List[int], None]]: + # Preprocess multidimensional input. + ds = None + if len(data.shape) == 2: + data = data.unsqueeze(0).unsqueeze(0) + elif len(data.shape) == 3: + # add a channel dimension for torch. + data = data.unsqueeze(1) + elif len(data.shape) >= 4: + data, ds = _fold_axes(data, 2) + data = data.unsqueeze(1) + elif len(data.shape) == 1: + raise ValueError("More than one input dimension required.") + return data, ds + + def wavedec2( data: torch.Tensor, wavelet: Union[Wavelet, str], mode: str = "reflect", level: Optional[int] = None, + axes: Tuple[int, int] = (-2, -1), ) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: """Non-separated two-dimensional wavelet transform. Only the last two axes change. @@ -157,6 +181,8 @@ def wavedec2( This function defaults to "reflect". level (int): The number of desired scales. Defaults to None. + axes (Tuple[int, int]): Compute the transform over these axes instead of the + last two. Defaults to (-2, -1). Returns: list: A list containing the wavelet coefficients. @@ -183,29 +209,17 @@ def wavedec2( >>> level=2, mode="zero") """ - fold = False - if data.dim() == 2: - data = data.unsqueeze(0).unsqueeze(0) - elif data.dim() == 3: - # add a channel dimension for torch. - data = data.unsqueeze(1) - elif data.dim() == 4: - # avoid the channel sum, fold the channels into batches. - fold = True - ds = data.shape - data = _fold_channels(data).unsqueeze(1) - elif data.dim() == 1: - raise ValueError("Wavedec2 needs more than one input dimension to work.") - else: - raise ValueError( - "Wavedec2 does not support more than four input dimensions. \ - Optionally-batched two-dimensional inputs work." - ) - if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + data = _swap_axes(data, list(axes)) + wavelet = _as_wavelet(wavelet) + data, ds = _preprocess_tensor_dec2d(data) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype ) @@ -225,11 +239,35 @@ def wavedec2( to_append = (res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1)) result_lst.append(to_append) result_lst.append(res_ll.squeeze(1)) + result_lst.reverse() + + if ds: + _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) + result_lst = _map_result(result_lst, _unfold_axes2) + + if axes != (-2, -1): + undo_swap_fn = partial(_undo_swap_axes, axes=axes) + result_lst = _map_result(result_lst, undo_swap_fn) - if fold: - result_lst = _wavedec2d_unfold_channels_2d_list(result_lst, list(ds)) + return result_lst - return result_lst[::-1] + +def _map_result( + data: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], + function: Callable[[Any], torch.Tensor], +) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: + # Apply the given function to the input list of tensor and tuples. + result_lst: List[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] + ] = [] + for element in data: + if isinstance(element, torch.Tensor): + result_lst.append(function(element)) + else: + result_lst.append( + (function(element[0]), function(element[1]), function(element[2])) + ) + return result_lst def _check_if_tensor(to_check: Any) -> torch.Tensor: @@ -245,6 +283,7 @@ def _check_if_tensor(to_check: Any) -> torch.Tensor: def waverec2( coeffs: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], wavelet: Union[Wavelet, str], + axes: Tuple[int, int] = (-2, -1), ) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. @@ -258,6 +297,8 @@ def waverec2( and D diagonal coefficients. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + axes (Tuple[int, int]): Compute the transform over these axes instead of the + last two. Defaults to (-2, -1). Returns: torch.Tensor: The reconstructed signal of shape ``[batch, height, width]`` or @@ -279,16 +320,23 @@ def waverec2( >>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar")) """ + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + _check_axes_argument(list(axes)) + swap_fn = partial(_swap_axes, axes=list(axes)) + coeffs = _map_result(coeffs, swap_fn) + + ds = None wavelet = _as_wavelet(wavelet) res_ll = _check_if_tensor(coeffs[0]) torch_device = res_ll.device torch_dtype = res_ll.dtype - fold = False - if res_ll.dim() == 4: + if res_ll.dim() >= 4: # avoid the channel sum, fold the channels into batches. - fold = True coeffs, ds = _waverec2d_fold_channels_2d_list(coeffs) res_ll = _check_if_tensor(coeffs[0]) @@ -348,7 +396,10 @@ def waverec2( if padr > 0: res_ll = res_ll[..., :-padr] - if fold: - res_ll = _unfold_channels(res_ll, list(ds)) + if ds: + res_ll = _unfold_axes(res_ll, list(ds), 2) + + if axes != (-2, -1): + res_ll = _undo_swap_axes(res_ll, list(axes)) return res_ll diff --git a/tests/test_convolution_fwt.py b/tests/test_convolution_fwt.py index c6787c6f..2b29bf23 100644 --- a/tests/test_convolution_fwt.py +++ b/tests/test_convolution_fwt.py @@ -251,3 +251,74 @@ def test_input_1d_dimension_error(): with pytest.raises(ValueError): data = torch.randn(50) wavedec2(data, "haar", 4) + + +def _compare_coeffs(ptwt_res, pywt_res): + test_list = [] + for ptwtcs, pywtcs in zip(ptwt_res, pywt_res): + if isinstance(ptwtcs, tuple): + test_list.extend( + tuple( + np.allclose(ptwtc.numpy(), pywtc) + for ptwtc, pywtc in zip(ptwtcs, pywtcs) + ) + ) + else: + test_list.append(np.allclose(ptwtcs.numpy(), pywtcs)) + return test_list + + +@pytest.mark.parametrize( + "size", [(50, 20, 128, 128), (8, 49, 21, 128, 128), (6, 4, 4, 5, 64, 64)] +) +def test_multidim_input(size): + """Test the error for multi-dimensional inputs to wavedec2.""" + data = torch.randn(*size, dtype=torch.float64) + wavelet = "db2" + level = 3 + + pt_res = wavedec2(data, wavelet=wavelet, level=level, mode="reflect") + pywt_res = pywt.wavedec2(data.numpy(), wavelet=wavelet, level=level, mode="reflect") + rec = waverec2(pt_res, wavelet) + + # test coefficients + test_list = _compare_coeffs(pt_res, pywt_res) + assert all(test_list) + + # test reconstruction. + assert np.allclose( + data.numpy(), rec.numpy()[..., : data.shape[-2], : data.shape[-1]] + ) + + +@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (-3, -2), (0, 1), (1, 0)]) +def test_axis_argument(axes): + """Ensure the axes argument works as expected.""" + data = torch.randn([32, 32, 32, 32], dtype=torch.float64) + + ptwt_coeff = wavedec2(data, "db2", level=3, mode="reflect", axes=axes) + pywt_coeff = pywt.wavedec2(data, "db2", level=3, mode="reflect", axes=axes) + rec = waverec2(ptwt_coeff, "db2", axes=axes) + + # test coefficients + test_list = _compare_coeffs(ptwt_coeff, pywt_coeff) + assert all(test_list) + + # test reconstruction. + assert np.allclose( + data.numpy(), rec.numpy()[..., : data.shape[-2], : data.shape[-1]] + ) + + +def test_axis_error_axes_count(): + """Check the error for too many axes.""" + with pytest.raises(ValueError): + data = torch.randn([32, 32, 32, 32], dtype=torch.float64) + wavedec2(data, "haar", 1, axes=(1, 2, 3)) + + +def test_axis_error_axes_repetition(): + """Check the error for axes repetition.""" + with pytest.raises(ValueError): + data = torch.randn([32, 32, 32, 32], dtype=torch.float64) + wavedec2(data, "haar", 1, axes=(2, 2))