diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3cc20355..7403759f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,6 +1,6 @@ name: Tests -on: [ push, pull_request ] +on: [ push ] jobs: tests: @@ -20,7 +20,7 @@ jobs: run: pip install nox - name: Test with pytest run: - nox -s test + nox -s fast-test lint: name: lint runs-on: ubuntu-latest diff --git a/README.rst b/README.rst index 642fcb25..6c81dc11 100644 --- a/README.rst +++ b/README.rst @@ -12,7 +12,7 @@ :alt: GitHub Actions .. image:: https://readthedocs.org/projects/pytorch-wavelet-toolbox/badge/?version=latest - :target: https://pytorch-wavelet-toolbox.readthedocs.io/en/latest/?badge=latest + :target: https://pytorch-wavelet-toolbox.readthedocs.io/en/latest/ptwt.html :alt: Documentation Status .. image:: https://img.shields.io/pypi/pyversions/ptwt diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 1912bc5c..6cd6cd1d 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -1,6 +1,7 @@ """Utility methods to compute wavelet decompositions from a dataset.""" -from typing import List, Optional, Protocol, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union +import numpy as np import pywt import torch @@ -103,19 +104,80 @@ def _pad_symmetric( return signal -def _fold_channels(data: torch.Tensor) -> torch.Tensor: - """Fold [batch, channel, height width] into [batch*channel, height, widht].""" - ds = data.shape - return torch.reshape( - data, - [ - ds[0] * ds[1], - ds[2], - ds[3], - ], +def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int]]: + """Fold unchanged leading dimensions into a single batch dimension. + + Args: + data ( torch.Tensor): The input data array. + keep_no (int): The number of dimensions to keep. + + Returns: + Tuple[ torch.Tensor, List[int]]: + The folded result array, and the shape of the original input. + """ + dshape = list(data.shape) + return ( + torch.reshape(data, [int(np.prod(dshape[:-keep_no]))] + dshape[-keep_no:]), + dshape, ) -def _unfold_channels(data: torch.Tensor, ds: List[int]) -> torch.Tensor: - """Unfold [batch*channel, height, widht] into [batch, channel, height, width].""" - return torch.reshape(data, [ds[0], ds[1], data.shape[1], data.shape[2]]) +def _unfold_axes(data: torch.Tensor, ds: List[int], keep_no: int) -> torch.Tensor: + """Unfold i.e. [batch*channel,height,widht] to [batch,channel,height,width].""" + return torch.reshape(data, ds[:-keep_no] + list(data.shape[-keep_no:])) + + +def _check_if_tensor(array: Any) -> torch.Tensor: + if not isinstance(array, torch.Tensor): + raise ValueError( + "First element of coeffs must be the approximation coefficient tensor." + ) + return array + + +def _check_axes_argument(axes: List[int]) -> None: + if len(set(axes)) != len(axes): + raise ValueError("Cant transform the same axis twice.") + + +def _get_transpose_order( + axes: List[int], data_shape: List[int] +) -> Tuple[List[int], List[int]]: + axes = list(map(lambda a: a + len(data_shape) if a < 0 else a, axes)) + all_axes = list(range(len(data_shape))) + remove_transformed = list(filter(lambda a: a not in axes, all_axes)) + return remove_transformed, axes + + +def _swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor: + _check_axes_argument(axes) + front, back = _get_transpose_order(axes, list(data.shape)) + return torch.permute(data, front + back) + + +def _undo_swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor: + _check_axes_argument(axes) + front, back = _get_transpose_order(axes, list(data.shape)) + restore_sorted = torch.argsort(torch.tensor(front + back)).tolist() + return torch.permute(data, restore_sorted) + + +def _map_result( + data: List[Union[torch.Tensor, Any]], # following jax tree_map typing can be Any + function: Callable[[Any], torch.Tensor], +) -> List[Union[torch.Tensor, Any]]: + # Apply the given function to the input list of tensor and tuples. + result_lst: List[Union[torch.Tensor, Any]] = [] + for element in data: + if isinstance(element, torch.Tensor): + result_lst.append(function(element)) + elif isinstance(element, tuple): + result_lst.append( + (function(element[0]), function(element[1]), function(element[2])) + ) + elif isinstance(element, dict): + new_dict = {} + for key, value in element.items(): + new_dict[key] = function(value) + result_lst.append(new_dict) + return result_lst diff --git a/src/ptwt/continuous_transform.py b/src/ptwt/continuous_transform.py index c5ca2f51..65285ef6 100644 --- a/src/ptwt/continuous_transform.py +++ b/src/ptwt/continuous_transform.py @@ -270,8 +270,8 @@ def wavefun( """Define a grid and evaluate the wavelet on it.""" length = 2**precision # load the bounds from untyped pywt code. - lower_bound: float = float(self.lower_bound) # type: ignore - upper_bound: float = float(self.upper_bound) # type: ignore + lower_bound: float = float(self.lower_bound) + upper_bound: float = float(self.upper_bound) grid = torch.linspace( lower_bound, upper_bound, diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 59b71f26..865fc82f 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -11,11 +11,11 @@ from ._util import ( Wavelet, _as_wavelet, - _fold_channels, + _fold_axes, _get_len, _is_dtype_supported, _pad_symmetric, - _unfold_channels, + _unfold_axes, ) @@ -217,32 +217,52 @@ def _adjust_padding_at_reconstruction( return pad_end, pad_start -def _wavedec_fold_channels_1d(data: torch.Tensor) -> Tuple[torch.Tensor, List[int]]: - data = data.unsqueeze(-2) - ds = data.shape - data = _fold_channels(data) - return data, list(ds) +def _preprocess_tensor_dec1d( + data: torch.Tensor, +) -> Tuple[torch.Tensor, Union[List[int], None]]: + """Preprocess input tensor dimensions. + Args: + data (torch.Tensor): An input tensor of any shape. -def _wavedec_unfold_channels_1d_list( - result_list: List[torch.Tensor], ds: List[int] + Returns: + Tuple[torch.Tensor, Union[List[int], None]]: + A data tensor of shape [new_batch, 1, to_process] + and the original shape, if the shape has changed. + """ + ds = None + if len(data.shape) == 1: + # assume time series + data = data.unsqueeze(0).unsqueeze(0) + elif len(data.shape) == 2: + # assume batched time series + data = data.unsqueeze(1) + else: + data, ds = _fold_axes(data, 1) + data = data.unsqueeze(1) + return data, ds + + +def _postprocess_result_list_dec1d( + result_lst: List[torch.Tensor], ds: List[int] ) -> List[torch.Tensor]: - unfold_res = [] - for res_coeff in result_list: - unfold_res.append( - _unfold_channels(res_coeff.unsqueeze(1), list(ds)).squeeze(-2) - ) - return unfold_res + # Unfold axes for the wavelets + unfold_list = [] + for fres in result_lst: + unfold_list.append(_unfold_axes(fres, ds, 1)) + return unfold_list -def _waverec_fold_channels_1d_list( - coeff_list: List[torch.Tensor], +def _preprocess_result_list_rec1d( + result_lst: List[torch.Tensor], ) -> Tuple[List[torch.Tensor], List[int]]: - folded = [] - ds = coeff_list[0].unsqueeze(-2).shape - for to_fold_coeff in coeff_list: - folded.append(_fold_channels(to_fold_coeff.unsqueeze(-2)).squeeze(-2)) - return folded, list(ds) + # Fold axes for the wavelets + fold_coeffs = [] + ds = list(result_lst[0].shape) + for uf_coeff in result_lst: + f_coeff, _ = _fold_axes(uf_coeff, 1) + fold_coeffs.append(f_coeff) + return fold_coeffs, ds def wavedec( @@ -250,14 +270,13 @@ def wavedec( wavelet: Union[Wavelet, str], mode: str = "reflect", level: Optional[int] = None, + axis: int = -1, ) -> List[torch.Tensor]: """Compute the analysis (forward) 1d fast wavelet transform. Args: data (torch.Tensor): The input time series, - 1d inputs are interpreted as ``[time]``, - 2d inputs as ``[batch_size, time]``, - and 3d inputs as ``[batch_size, channels, time]``. + By default the last axis is transformed. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Please consider the output from ``pywt.wavelist(kind='discrete')`` @@ -274,9 +293,11 @@ def wavedec( Zero padding pads zeros. Constant padding replicates border values. Periodic padding cyclically repeats samples. - level (int): The scale level to be computed. Defaults to None. + axis (int): Compute the transform over this axis instead of the + last one. Defaults to -1. + Returns: list: A list:: @@ -287,7 +308,8 @@ def wavedec( approximation and D detail coefficients. Raises: - ValueError: If the dtype of the input data tensor is unsupported. + ValueError: If the dtype of the input data tensor is unsupported or + if more than one axis is provided. Example: >>> import torch @@ -300,21 +322,17 @@ def wavedec( >>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'), >>> mode='zero', level=2) """ - fold = False - if data.dim() == 1: - # assume time series - data = data.unsqueeze(0).unsqueeze(0) - elif data.dim() == 2: - # assume batched time series - data = data.unsqueeze(1) - elif data.dim() == 3: - # assume batch, channels, time -> fold channels - fold = True - data, ds = _wavedec_fold_channels_1d(data) + if axis != -1: + if isinstance(axis, int): + data = data.swapaxes(axis, -1) + else: + raise ValueError("wavedec transforms a single axis only.") if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") + data, ds = _preprocess_tensor_dec1d(data) + dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype ) @@ -332,28 +350,38 @@ def wavedec( res_lo, res_hi = torch.split(res, 1, 1) result_list.append(res_hi.squeeze(1)) result_list.append(res_lo.squeeze(1)) + result_list.reverse() + + if ds: + result_list = _postprocess_result_list_dec1d(result_list, ds) - # unfold if necessary - if fold: - result_list = _wavedec_unfold_channels_1d_list(result_list, ds) + if axis != -1: + swap = [] + for coeff in result_list: + swap.append(coeff.swapaxes(axis, -1)) + result_list = swap - return result_list[::-1] + return result_list -def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.Tensor: +def waverec( + coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str], axis: int = -1 +) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. Args: coeffs (list): The wavelet coefficient list produced by wavedec. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + axis (int): Transform this axis instead of the last one. Defaults to -1. Returns: torch.Tensor: The reconstructed signal. Raises: ValueError: If the dtype of the coeffs tensor is unsupported or if the - coefficients have incompatible shapes, dtypes or devices. + coefficients have incompatible shapes, dtypes or devices or if + more than one axis is provided. Example: >>> import torch @@ -379,11 +407,19 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T elif torch_dtype != coeff.dtype: raise ValueError("coefficients must have the same dtype") + if axis != -1: + swap = [] + if isinstance(axis, int): + for coeff in coeffs: + swap.append(coeff.swapaxes(axis, -1)) + coeffs = swap + else: + raise ValueError("waverec transforms a single axis only.") + # fold channels, if necessary. - fold = False - if coeffs[0].dim() == 3: - fold = True - coeffs, ds = _waverec_fold_channels_1d_list(coeffs) + ds = None + if coeffs[0].dim() >= 3: + coeffs, ds = _preprocess_result_list_rec1d(coeffs) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -408,7 +444,10 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T if padr > 0: res_lo = res_lo[..., :-padr] - if fold: - res_lo = _unfold_channels(res_lo.unsqueeze(-2), list(ds)).squeeze(-2) + if ds: + res_lo = _unfold_axes(res_lo, ds, 1) + + if axis != -1: + res_lo = res_lo.swapaxes(axis, -1) return res_lo diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index df1a3596..cb56011d 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -5,7 +5,8 @@ """ -from typing import Any, List, Optional, Tuple, Union +from functools import partial +from typing import List, Optional, Tuple, Union import pywt import torch @@ -13,12 +14,17 @@ from ._util import ( Wavelet, _as_wavelet, - _fold_channels, + _check_axes_argument, + _check_if_tensor, + _fold_axes, _get_len, _is_dtype_supported, + _map_result, _outer, _pad_symmetric, - _unfold_channels, + _swap_axes, + _undo_swap_axes, + _unfold_axes, ) from .conv_transform import ( _adjust_padding_at_reconstruction, @@ -85,30 +91,6 @@ def _fwt_pad2( return data_pad -def _wavedec2d_unfold_channels_2d_list( - result_list: List[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ], - ds: List[int], -) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: - # unfolds the wavedec2d result lists, restoring the channel dimension. - unfold_res: List[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] = [] - for cres in result_list: - if isinstance(cres, torch.Tensor): - unfold_res.append(_unfold_channels(cres, list(ds))) - else: - unfold_res.append( - ( - _unfold_channels(cres[0], list(ds)), - _unfold_channels(cres[1], list(ds)), - _unfold_channels(cres[2], list(ds)), - ) - ) - return unfold_res - - def _waverec2d_fold_channels_2d_list( coeffs: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], ) -> Tuple[ @@ -116,22 +98,26 @@ def _waverec2d_fold_channels_2d_list( List[int], ]: # fold the input coefficients for processing conv2d_transpose. - fold_coeffs: List[ - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] = [] ds = list(_check_if_tensor(coeffs[0]).shape) - for coeff in coeffs: - if isinstance(coeff, torch.Tensor): - fold_coeffs.append(_fold_channels(coeff)) - else: - fold_coeffs.append( - ( - _fold_channels(coeff[0]), - _fold_channels(coeff[1]), - _fold_channels(coeff[2]), - ) - ) - return fold_coeffs, ds + return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds + + +def _preprocess_tensor_dec2d( + data: torch.Tensor, +) -> Tuple[torch.Tensor, Union[List[int], None]]: + # Preprocess multidimensional input. + ds = None + if len(data.shape) == 2: + data = data.unsqueeze(0).unsqueeze(0) + elif len(data.shape) == 3: + # add a channel dimension for torch. + data = data.unsqueeze(1) + elif len(data.shape) >= 4: + data, ds = _fold_axes(data, 2) + data = data.unsqueeze(1) + elif len(data.shape) == 1: + raise ValueError("More than one input dimension required.") + return data, ds def wavedec2( @@ -139,14 +125,16 @@ def wavedec2( wavelet: Union[Wavelet, str], mode: str = "reflect", level: Optional[int] = None, + axes: Tuple[int, int] = (-2, -1), ) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: """Non-separated two-dimensional wavelet transform. Only the last two axes change. Args: - data (torch.Tensor): The input data tensor with up to three dimensions. - 2d inputs are interpreted as ``[height, width]``, + data (torch.Tensor): The input data tensor with any number of dimensions. + By default 2d inputs are interpreted as ``[height, width]``, 3d inputs are interpreted as ``[batch_size, height, width]``. 4d inputs are interpreted as ``[batch_size, channels, height, width]``. + the ``axis`` argument allows other interpretations. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of ``pywt.wavelist(kind="discrete")`` for a list of possible choices. @@ -157,6 +145,8 @@ def wavedec2( This function defaults to "reflect". level (int): The number of desired scales. Defaults to None. + axes (Tuple[int, int]): Compute the transform over these axes instead of the + last two. Defaults to (-2, -1). Returns: list: A list containing the wavelet coefficients. @@ -169,7 +159,7 @@ def wavedec2( Raises: ValueError: If the dimensionality or the dtype of the input data tensor - is unsupported. + is unsupported or if the provided axes input has length other than two. Example: >>> import torch @@ -183,29 +173,17 @@ def wavedec2( >>> level=2, mode="zero") """ - fold = False - if data.dim() == 2: - data = data.unsqueeze(0).unsqueeze(0) - elif data.dim() == 3: - # add a channel dimension for torch. - data = data.unsqueeze(1) - elif data.dim() == 4: - # avoid the channel sum, fold the channels into batches. - fold = True - ds = data.shape - data = _fold_channels(data).unsqueeze(1) - elif data.dim() == 1: - raise ValueError("Wavedec2 needs more than one input dimension to work.") - else: - raise ValueError( - "Wavedec2 does not support four input dimensions. \ - Optionally-batched two-dimensional inputs work." - ) - if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + data = _swap_axes(data, list(axes)) + wavelet = _as_wavelet(wavelet) + data, ds = _preprocess_tensor_dec2d(data) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype ) @@ -225,26 +203,23 @@ def wavedec2( to_append = (res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1)) result_lst.append(to_append) result_lst.append(res_ll.squeeze(1)) + result_lst.reverse() - if fold: - result_lst = _wavedec2d_unfold_channels_2d_list(result_lst, list(ds)) - - return result_lst[::-1] + if ds: + _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) + result_lst = _map_result(result_lst, _unfold_axes2) + if axes != (-2, -1): + undo_swap_fn = partial(_undo_swap_axes, axes=axes) + result_lst = _map_result(result_lst, undo_swap_fn) -def _check_if_tensor(to_check: Any) -> torch.Tensor: - # Ensuring the first list elements are tensors makes mypy happy :-). - if not isinstance(to_check, torch.Tensor): - raise ValueError( - "First element of coeffs must be the approximation coefficient tensor." - ) - else: - return to_check + return result_lst def waverec2( coeffs: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], wavelet: Union[Wavelet, str], + axes: Tuple[int, int] = (-2, -1), ) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. @@ -258,6 +233,8 @@ def waverec2( and D diagonal coefficients. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + axes (Tuple[int, int]): Compute the transform over these axes instead of the + last two. Defaults to (-2, -1). Returns: torch.Tensor: The reconstructed signal of shape ``[batch, height, width]`` or @@ -265,7 +242,8 @@ def waverec2( Raises: ValueError: If coeffs is not in a shape as returned from wavedec2 or - if the dtype is not supported. + if the dtype is not supported or if the provided axes input has length other + than two or if the same axes it repeated twice. Example: >>> import ptwt, pywt, torch @@ -279,16 +257,23 @@ def waverec2( >>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar")) """ + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + _check_axes_argument(list(axes)) + swap_fn = partial(_swap_axes, axes=list(axes)) + coeffs = _map_result(coeffs, swap_fn) + + ds = None wavelet = _as_wavelet(wavelet) res_ll = _check_if_tensor(coeffs[0]) torch_device = res_ll.device torch_dtype = res_ll.dtype - fold = False - if res_ll.dim() == 4: + if res_ll.dim() >= 4: # avoid the channel sum, fold the channels into batches. - fold = True coeffs, ds = _waverec2d_fold_channels_2d_list(coeffs) res_ll = _check_if_tensor(coeffs[0]) @@ -348,7 +333,10 @@ def waverec2( if padr > 0: res_ll = res_ll[..., :-padr] - if fold: - res_ll = _unfold_channels(res_ll, list(ds)) + if ds: + res_ll = _unfold_axes(res_ll, list(ds), 2) + + if axes != (-2, -1): + res_ll = _undo_swap_axes(res_ll, list(axes)) return res_ll diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 0ecbfaf8..31d8061d 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -3,7 +3,8 @@ The functions here are based on torch.nn.functional.conv3d and it's transpose. """ -from typing import Dict, List, Optional, Sequence, Union, cast +from functools import partial +from typing import Dict, List, Optional, Sequence, Tuple, Union, cast import pywt import torch @@ -11,10 +12,17 @@ from ._util import ( Wavelet, _as_wavelet, + _check_axes_argument, + _check_if_tensor, + _fold_axes, _get_len, _is_dtype_supported, + _map_result, _outer, _pad_symmetric, + _swap_axes, + _undo_swap_axes, + _unfold_axes, ) from .conv_transform import ( _adjust_padding_at_reconstruction, @@ -97,11 +105,12 @@ def wavedec3( wavelet: Union[Wavelet, str], mode: str = "zero", level: Optional[int] = None, + axes: Tuple[int, int, int] = (-3, -2, -1), ) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a three-dimensional wavelet transform. Args: - data (torch.Tensor): The input data of shape + data (torch.Tensor): The input data. For example of shape [batch_size, length, height, width] wavelet (Union[Wavelet, str]): The wavelet to transform with. ``pywt.wavelist(kind='discrete')`` lists possible choices. @@ -112,6 +121,8 @@ def wavedec3( Defaults to "zero". level (Optional[int]): The maximum decomposition level. This argument defaults to None. + axes (Tuple[int, int, int]): Compute the transform over these axes + instead of the last three. Defaults to (-3, -2, -1). Returns: list: A list with the lll coefficients and dictionaries @@ -124,7 +135,8 @@ def wavedec3( Raises: ValueError: If the input has fewer than three dimensions or - if the dtype is not supported. + if the dtype is not supported or + if the provided axes input has length other than three. Example: >>> import ptwt, torch @@ -132,11 +144,21 @@ def wavedec3( >>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect") """ + if tuple(axes) != (-3, -2, -1): + if len(axes) != 3: + raise ValueError("3D transforms work with three axes.") + else: + _check_axes_argument(list(axes)) + data = _swap_axes(data, list(axes)) + + ds = None if data.dim() < 3: - raise ValueError("Three dimensional inputs required for 3d wavedec.") - elif data.dim() == 3: - # add batch dim. - data = data.unsqueeze(0) + raise ValueError("At least three dimensions are required for 3d wavedec.") + elif len(data.shape) == 3: + data = data.unsqueeze(1) + else: + data, ds = _fold_axes(data, 3) + data = data.unsqueeze(1) if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") @@ -155,7 +177,9 @@ def wavedec3( result_lst: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] res_lll = data for _ in range(level): - res_lll = _fwt_pad3(res_lll.unsqueeze(1), wavelet, mode=mode) + if len(res_lll.shape) == 4: + res_lll = res_lll.unsqueeze(1) + res_lll = _fwt_pad3(res_lll, wavelet, mode=mode) res = torch.nn.functional.conv3d(res_lll, dec_filt, stride=2) res_lll, res_llh, res_lhl, res_lhh, res_hll, res_hlh, res_hhl, res_hhh = [ sr.squeeze(1) for sr in torch.split(res, 1, 1) @@ -172,12 +196,40 @@ def wavedec3( } ) result_lst.append(res_lll) - return result_lst[::-1] + result_lst.reverse() + + if ds: + _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) + result_lst = _map_result(result_lst, _unfold_axes_fn) + + if tuple(axes) != (-3, -2, -1): + undo_swap_fn = partial(_undo_swap_axes, axes=axes) + result_lst = _map_result(result_lst, undo_swap_fn) + + return result_lst + + +def _waverec3d_fold_channels_3d_list( + coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], +) -> Tuple[List[Union[torch.Tensor, Dict[str, torch.Tensor]]], List[int],]: + # fold the input coefficients for processing conv2d_transpose. + fold_coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] + ds = list(_check_if_tensor(coeffs[0]).shape) + for coeff in coeffs: + if isinstance(coeff, torch.Tensor): + fold_coeffs.append(_fold_axes(coeff, 3)[0]) + else: + new_dict = {} + for key, value in coeff.items(): + new_dict[key] = _fold_axes(value, 3)[0] + fold_coeffs.append(new_dict) + return fold_coeffs, ds def waverec3( - coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: Union[Wavelet, str], + axes: Tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Reconstruct a signal from wavelet coefficients. @@ -185,6 +237,8 @@ def waverec3( coeffs (list): The wavelet coefficient list produced by wavedec3. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + axes (Tuple[int, int, int]): Transform these axes instead of the + last three. Defaults to (-3, -2, -1). Returns: torch.Tensor: The reconstructed four-dimensional signal of shape @@ -192,7 +246,8 @@ def waverec3( Raises: ValueError: If coeffs is not in a shape as returned from wavedec3 or - if the dtype is not supported. + if the dtype is not supported or if the provided axes input has length + other than three or if the same axes it repeated three. Example: >>> import ptwt, torch @@ -201,13 +256,25 @@ def waverec3( >>> reconstruction = ptwt.waverec3(transformed, "haar") """ + if tuple(axes) != (-3, -2, -1): + if len(axes) != 3: + raise ValueError("3D transforms work with three axes") + else: + _check_axes_argument(list(axes)) + swap_axes_fn = partial(_swap_axes, axes=list(axes)) + coeffs = _map_result(coeffs, swap_axes_fn) + wavelet = _as_wavelet(wavelet) + ds = None # the Union[tensor, dict] idea is coming from pywt. We don't change it here. - res_lll = coeffs[0] - if not isinstance(res_lll, torch.Tensor): + res_lll = _check_if_tensor(coeffs[0]) + if res_lll.dim() < 3: raise ValueError( - "First element of coeffs must be the approximation coefficient tensor." + "Three dimensional transforms require at least three dimensions." ) + elif res_lll.dim() >= 5: + coeffs, ds = _waverec3d_fold_channels_3d_list(coeffs) + res_lll = _check_if_tensor(coeffs[0]) torch_device = res_lll.device torch_dtype = res_lll.dtype @@ -284,4 +351,11 @@ def waverec3( res_lll = res_lll[..., padfr:, :, :] if padba > 0: res_lll = res_lll[..., :-padba, :, :] + res_lll = res_lll.squeeze(1) + + if ds: + res_lll = _unfold_axes(res_lll, ds, 3) + + if axes != (-3, -2, -1): + res_lll = _undo_swap_axes(res_lll, list(axes)) return res_lll diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 2c830730..a31191d9 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -18,13 +18,13 @@ _as_wavelet, _is_boundary_mode_supported, _is_dtype_supported, - _unfold_channels, + _unfold_axes, ) from .conv_transform import ( _get_filter_tensors, - _wavedec_fold_channels_1d, - _wavedec_unfold_channels_1d_list, - _waverec_fold_channels_1d_list, + _postprocess_result_list_dec1d, + _preprocess_result_list_rec1d, + _preprocess_tensor_dec1d, ) from .sparse_math import ( _orth_by_gram_schmidt, @@ -176,6 +176,7 @@ def __init__( self, wavelet: Union[Wavelet, str], level: Optional[int] = None, + axis: Optional[int] = -1, boundary: str = "qr", ) -> None: """Create a matrix-fwt object. @@ -186,6 +187,8 @@ def __init__( level (int, optional): The level up to which to compute the fwt. If None, the maximum level based on the signal length is chosen. Defaults to None. + axis (int, optional): The axis we would like to transform. + Defaults to -1. boundary (str): The method used for boundary filter treatment. Choose 'qr' or 'gramschmidt'. 'qr' relies on pytorch's dense qr implementation, it is fast but memory hungry. The 'gramschmidt' @@ -194,12 +197,18 @@ def __init__( Raises: NotImplementedError: If the selected `boundary` mode is not supported. - ValueError: If the wavelet filters have different lengths. + ValueError: If the wavelet filters have different lengths or + if axis is not an integer. """ self.wavelet = _as_wavelet(wavelet) self.level = level self.boundary = boundary + if isinstance(axis, int): + self.axis = axis + else: + raise ValueError("MatrixWavedec transforms a single axis only.") + self.input_length: Optional[int] = None self.fwt_matrix_list: List[torch.Tensor] = [] self.pad_list: List[bool] = [] @@ -304,10 +313,12 @@ def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]: Matrix FWTs are used to avoid padding. Args: - input_signal (torch.Tensor): Batched input data ``[batch_size, time]``, - should be of even length. 1d inputs are interpreted as ``[time]``. - 3d inputs are treated as ``[batch, channels, time]``. - This transform only affects the last axis. + input_signal (torch.Tensor): Batched input data. + An example shape could be ``[batch_size, time]``. + Inputs can have any dimension. + This transform affects the last axis by default. + Use the axis argument in the constructor to choose + another axis. Returns: List[torch.Tensor]: A list with the coefficients for each scale. @@ -316,21 +327,11 @@ def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]: ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. """ - fold = False - if input_signal.dim() == 1: - # assume time series - input_signal = input_signal.unsqueeze(0) - elif input_signal.dim() == 3: - # assume batch, channels, time -> fold channels - fold = True - input_signal, ds = _wavedec_fold_channels_1d(input_signal) - input_signal = input_signal.squeeze(1) - elif input_signal.dim() > 3: - raise ValueError( - f"Invalid input tensor shape {input_signal.size()}. " - "The input signal is expected to be of the form " - "[batch_size, (channels), length]." - ) + if self.axis != -1: + input_signal = input_signal.swapaxes(self.axis, -1) + + input_signal, ds = _preprocess_tensor_dec1d(input_signal) + input_signal = input_signal.squeeze(1) if not _is_dtype_supported(input_signal.dtype): raise ValueError(f"Input dtype {input_signal.dtype} not supported") @@ -373,8 +374,14 @@ def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]: result_list = [s.T for s in split_list[::-1]] # unfold if necessary - if fold: - result_list = _wavedec_unfold_channels_1d_list(result_list, ds) + if ds: + result_list = _postprocess_result_list_dec1d(result_list, ds) + + if self.axis != -1: + swap = [] + for coeff in result_list: + swap.append(coeff.swapaxes(self.axis, -1)) + result_list = swap return result_list @@ -455,15 +462,15 @@ class MatrixWaverec(object): """ def __init__( - self, - wavelet: Union[Wavelet, str], - boundary: str = "qr", + self, wavelet: Union[Wavelet, str], axis: int = -1, boundary: str = "qr" ) -> None: """Create the inverse matrix-based fast wavelet transformation. Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + axis (int): The axis transformed by the original decomposition + defaults to -1 or the last axis. boundary (str): The method used for boundary filter treatment. Choose 'qr' or 'gramschmidt'. 'qr' relies on pytorch's dense qr implementation, it is fast but memory hungry. The 'gramschmidt' option @@ -472,10 +479,15 @@ def __init__( Raises: NotImplementedError: If the selected `boundary` mode is not supported. - ValueError: If the wavelet filters have different lengths. + ValueError: If the wavelet filters have different lengths or if + axis is not an integer. """ self.wavelet = _as_wavelet(wavelet) self.boundary = boundary + if isinstance(axis, int): + self.axis = axis + else: + raise ValueError("MatrixWaverec transforms a single axis only.") self.ifwt_matrix_list: List[torch.Tensor] = [] self.level: Optional[int] = None @@ -587,10 +599,15 @@ def __call__(self, coefficients: List[torch.Tensor]) -> torch.Tensor: coefficients are not in the shape as it is returned from a `MatrixWavedec` object. """ - fold = False - if coefficients[0].dim() == 3: - fold = True - coefficients, ds = _waverec_fold_channels_1d_list(coefficients) + if self.axis != -1: + swap = [] + for coeff in coefficients: + swap.append(coeff.swapaxes(self.axis, -1)) + coefficients = swap + + ds = None + if coefficients[0].ndim > 2: + coefficients, ds = _preprocess_result_list_rec1d(coefficients) level = len(coefficients) - 1 input_length = coefficients[-1].shape[-1] * 2 @@ -642,7 +659,10 @@ def __call__(self, coefficients: List[torch.Tensor]) -> torch.Tensor: res_lo = lo.T - if fold: - res_lo = _unfold_channels(res_lo.unsqueeze(-2), list(ds)).squeeze(-2) + if ds: + res_lo = _unfold_axes(res_lo.unsqueeze(-2), list(ds), 1).squeeze(-2) + + if self.axis != -1: + res_lo = res_lo.swapaxes(self.axis, -1) return res_lo diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index f82cbaa4..1a1642ab 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -4,6 +4,7 @@ """ # Written by moritz ( @ wolter.tech ) in 2021 import sys +from functools import partial from typing import List, Optional, Tuple, Union, cast import numpy as np @@ -12,16 +13,19 @@ from ._util import ( Wavelet, _as_wavelet, - _fold_channels, + _check_axes_argument, + _check_if_tensor, _is_boundary_mode_supported, _is_dtype_supported, - _unfold_channels, + _map_result, + _swap_axes, + _undo_swap_axes, + _unfold_axes, ) from .conv_transform import _get_filter_tensors from .conv_transform_2 import ( - _check_if_tensor, _construct_2d_filt, - _wavedec2d_unfold_channels_2d_list, + _preprocess_tensor_dec2d, _waverec2d_fold_channels_2d_list, ) from .matmul_transform import construct_boundary_a, construct_boundary_s, orthogonalize @@ -245,6 +249,7 @@ def __init__( self, wavelet: Union[Wavelet, str], level: Optional[int] = None, + axes: Tuple[int, int] = (-2, -1), boundary: str = "qr", separable: bool = True, ): @@ -256,6 +261,8 @@ def __init__( level (int, optional): The level up to which to compute the fwt. If None, the maximum level based on the signal length is chosen. Defaults to None. + axes (int, int): A tuple with the axes to transform. + Defaults to (-2, -1). boundary (str): The method used for boundary filter treatment. Choose 'qr' or 'gramschmidt'. 'qr' relies on Pytorch's dense qr implementation, it is fast but memory hungry. @@ -274,6 +281,11 @@ def __init__( ValueError: If the wavelet filters have different lengths. """ self.wavelet = _as_wavelet(wavelet) + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + _check_axes_argument(list(axes)) + self.axes = tuple(axes) self.level = level self.boundary = boundary self.separable = separable @@ -422,23 +434,11 @@ def __call__( ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. """ - fold = False - if input_signal.dim() == 2: - # add batch dim to unbatched input - input_signal = input_signal.unsqueeze(0) - elif input_signal.dim() == 4: - # we assume the shape [batch_size, color_channels, height, width] - # and fold the color channel - fold = True - ds = input_signal.shape - input_signal = _fold_channels(input_signal) - elif input_signal.dim() != 3: - raise ValueError( - f"Invalid input tensor shape {input_signal.size()}. " - "The input signal is expected to be of the form " - "[batch_size, height, width] or " - "[batch_size, channels, height, width]." - ) + if self.axes != (-2, -1): + input_signal = _swap_axes(input_signal, list(self.axes)) + + input_signal, ds = _preprocess_tensor_dec2d(input_signal) + input_signal = input_signal.squeeze(1) batch_size, height, width = input_signal.shape @@ -539,8 +539,13 @@ def __call__( ll.T.reshape(batch_size, size[1] // 2, size[0] // 2).transpose(2, 1) ) - if fold: - split_list = _wavedec2d_unfold_channels_2d_list(split_list, list(ds)) + if ds: + _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) + split_list = _map_result(split_list, _unfold_axes2) + + if self.axes != (-2, -1): + undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) + split_list = _map_result(split_list, undo_swap_fn) return split_list[::-1] @@ -563,6 +568,7 @@ class MatrixWaverec2(object): def __init__( self, wavelet: Union[Wavelet, str], + axes: Tuple[int, int] = (-2, -1), boundary: str = "qr", separable: bool = True, ): @@ -571,6 +577,8 @@ def __init__( Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + axes (int, int): The axes transformed by waverec2. + Defaults to (-2, -1). boundary (str): The method used for boundary filter treatment. Choose 'qr' or 'gramschmidt'. 'qr' relies on pytorch's dense qr implementation, it is fast but memory hungry. The 'gramschmidt' option @@ -591,6 +599,12 @@ def __init__( self.boundary = boundary self.separable = separable + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + _check_axes_argument(list(axes)) + self.axes = axes + self.ifwt_matrix_list: List[ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ] = [] @@ -716,10 +730,12 @@ def __call__( by the `MatrixWavedec2`-Object. Returns: - torch.Tensor: The original signal reconstruction of - shape ``[batch_size, height, width]`` or + torch.Tensor: The original signal reconstruction. + For example of shape + ``[batch_size, height, width]`` or ``[batch_size, channels, height, width]`` depending on the input to the forward transform. + and the value of the `axis` argument. Raises: ValueError: If the decomposition level is not a positive integer or if the @@ -727,10 +743,20 @@ def __call__( `MatrixWavedec2` object. """ ll = _check_if_tensor(coefficients[0]) - fold = False - if ll.dim() == 4: - # fold all channels into the batches. - fold = True + + if tuple(self.axes) != (-2, -1): + swap_fn = partial(_swap_axes, axes=list(self.axes)) + coefficients = _map_result(coefficients, swap_fn) + ll = _check_if_tensor(coefficients[0]) + + ds = None + if ll.dim() == 1: + raise ValueError("2d transforms require more than a single input dim.") + elif ll.dim() == 2: + # add batch dim to unbatched input + ll = ll.unsqueeze(0) + elif ll.dim() >= 4: + # avoid the channel sum, fold the channels into batches. coefficients, ds = _waverec2d_fold_channels_2d_list(coefficients) ll = _check_if_tensor(coefficients[0]) @@ -825,7 +851,9 @@ def __call__( if pred_len[1] != next_len[1]: ll = ll[:, :, :-1] - if fold: - ll = _unfold_channels(ll, list(ds)) + if ds: + ll = _unfold_axes(ll, list(ds), 2) + if self.axes != (-2, -1): + ll = _undo_swap_axes(ll, list(self.axes)) return ll diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index def5930e..f27f3dbb 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -9,9 +9,17 @@ from ._util import ( Wavelet, _as_wavelet, + _check_axes_argument, + _check_if_tensor, + _fold_axes, _is_boundary_mode_supported, _is_dtype_supported, + _map_result, + _swap_axes, + _undo_swap_axes, + _unfold_axes, ) +from .conv_transform_3 import _waverec3d_fold_channels_3d_list from .matmul_transform import construct_boundary_a, construct_boundary_s from .sparse_math import _batch_dim_mm @@ -47,6 +55,7 @@ def __init__( self, wavelet: Union[Wavelet, str], level: Optional[int] = None, + axes: Tuple[int, int, int] = (-3, -2, -1), boundary: Optional[str] = "qr", ): """Create a *separable* three-dimensional fast boundary wavelet transform. @@ -70,6 +79,11 @@ def __init__( self.wavelet = _as_wavelet(wavelet) self.level = level self.boundary = boundary + if len(axes) != 3: + raise ValueError("3D transforms work with three axes.") + else: + _check_axes_argument(list(axes)) + self.axes = axes self.input_signal_shape: Optional[Tuple[int, int, int]] = None self.fwt_matrix_list: List[List[torch.Tensor]] = [] @@ -145,8 +159,8 @@ def __call__( """Compute a separable 3d-boundary wavelet transform. Args: - input_signal (torch.Tensor): An input signal of shape - [batch_size, depth, height, width]. + input_signal (torch.Tensor): An input signal. For example + of shape [batch_size, depth, height, width]. Raises: ValueError: If the input dimensions don't work. @@ -156,15 +170,16 @@ def __call__( A list with the approximation coefficients, and a coefficient dict for each scale. """ - if input_signal.dim() == 3: - # add batch dim to unbatched input - input_signal = input_signal.unsqueeze(0) - elif input_signal.dim() != 4: - raise ValueError( - f"Invalid input tensor shape {input_signal.size()}. " - "The input signal is expected to be of the form " - "[batch_size, depth, height, width]." - ) + if self.axes != (-3, -2, -1): + input_signal = _swap_axes(input_signal, list(self.axes)) + + ds = None + if input_signal.dim() < 3: + raise ValueError("At least three dimensions are required for 3d wavedec.") + elif len(input_signal.shape) == 3: + input_signal = input_signal.unsqueeze(1) + else: + input_signal, ds = _fold_axes(input_signal, 3) _, depth, height, width = input_signal.shape @@ -242,6 +257,15 @@ def _split_rec( } split_list.append(coeff_dict) split_list.append(lll) + + if ds: + _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) + split_list = _map_result(split_list, _unfold_axes_fn) + + if self.axes != (-3, -2, -1): + undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) + split_list = _map_result(split_list, undo_swap_fn) + return split_list[::-1] @@ -251,6 +275,7 @@ class MatrixWaverec3(object): def __init__( self, wavelet: Union[Wavelet, str], + axes: Tuple[int, int, int] = (-3, -2, -1), boundary: str = "qr", ): """Compute a three-dimensional separable boundary wavelet synthesis transform. @@ -258,6 +283,8 @@ def __init__( Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. + axes (Tuple[int, int, int]): Transform these axes instead of the + last three. Defaults to (-3, -2, -1). boundary (str): The method used for boundary filter treatment. Choose 'qr' or 'gramschmidt'. 'qr' relies on Pytorch's dense qr implementation, it is fast but memory hungry. The 'gramschmidt' option @@ -269,6 +296,11 @@ def __init__( ValueError: If the wavelet filters have different lengths. """ self.wavelet = _as_wavelet(wavelet) + if len(axes) != 3: + raise ValueError("3D transforms work with three axes") + else: + _check_axes_argument(list(axes)) + self.axes = axes self.boundary = boundary self.ifwt_matrix_list: List[List[torch.Tensor]] = [] self.input_signal_shape: Optional[Tuple[int, int, int]] = None @@ -366,6 +398,21 @@ def __call__( Raises: ValueError: If the data structure is inconsistent. """ + if self.axes != (-3, -2, -1): + swap_axes_fn = partial(_swap_axes, axes=list(self.axes)) + coefficients = _map_result(coefficients, swap_axes_fn) + + ds = None + # the Union[tensor, dict] idea is coming from pywt. We don't change it here. + res_lll = _check_if_tensor(coefficients[0]) + if res_lll.dim() < 3: + raise ValueError( + "Three dimensional transforms require at least three dimensions." + ) + elif res_lll.dim() >= 5: + coefficients, ds = _waverec3d_fold_channels_3d_list(coefficients) + res_lll = _check_if_tensor(coefficients[0]) + level = len(coefficients) - 1 if type(coefficients[-1]) is dict: depth, height, width = tuple( @@ -433,4 +480,10 @@ def __call__( for dim, mat in enumerate(self.ifwt_matrix_list[level - 1 - c_pos][::-1]): lll = _batch_dim_mm(mat, lll, dim=(-1) * (dim + 1)) + if ds: + lll = _unfold_axes(lll, ds, 3) + + if self.axes != (-3, -2, -1): + lll = _undo_swap_axes(lll, list(self.axes)) + return lll diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index b876b265..d9f6019f 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -48,9 +48,10 @@ def __init__( self, data: Optional[torch.Tensor], wavelet: Union[Wavelet, str], - mode: str = "reflect", - boundary_orthogonalization: str = "qr", + mode: Optional[str] = "reflect", maxlevel: Optional[int] = None, + axis: int = -1, + boundary_orthogonalization: str = "qr", ) -> None: """Create a wavelet packet decomposition object. @@ -61,16 +62,19 @@ def __init__( ``[batch_size, time]`` or ``[batch_size, channels, time]``. If None, the object is initialized without performing a decomposition. + The time axis is transformed by default. + Use the ``axis`` argument to choose another dimension. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. - mode (str): The desired padding method. If you select 'boundary', + mode (str, optional): The desired padding method. If you select 'boundary', the sparse matrix backend will be used. Defaults to 'reflect'. - boundary_orthogonalization (str): The orthogonalization method - to use. Only used if `mode` equals 'boundary'. Choose from - 'qr' or 'gramschmidt'. Defaults to 'qr'. maxlevel (int, optional): Value is passed on to `transform`. The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. + axis (int): The axis to transform. Defaults to -1. + boundary_orthogonalization (str): The orthogonalization method + to use. Only used if `mode` equals 'boundary'. Choose from + 'qr' or 'gramschmidt'. Defaults to 'qr'. Example: >>> import torch, pywt, ptwt @@ -95,6 +99,7 @@ def __init__( self._matrix_wavedec_dict: Dict[int, MatrixWavedec] = {} self._matrix_waverec_dict: Dict[int, MatrixWaverec] = {} self.maxlevel: Optional[int] = None + self.axis = axis if data is not None: if len(data.shape) == 1: # add a batch dimension. @@ -109,8 +114,8 @@ def transform( """Calculate the 1d wavelet packet transform for the input data. Args: - data (torch.Tensor): The input data array of shape [time] - or [batch_size, time]. + data (torch.Tensor): The input data array of shape ``[time]`` + or ``[batch_size, time]``. maxlevel (int, optional): The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. @@ -166,11 +171,13 @@ def _get_wavedec( if self.mode == "boundary": if length not in self._matrix_wavedec_dict.keys(): self._matrix_wavedec_dict[length] = MatrixWavedec( - self.wavelet, level=1, boundary=self.boundary + self.wavelet, level=1, boundary=self.boundary, axis=self.axis ) return self._matrix_wavedec_dict[length] else: - return partial(wavedec, wavelet=self.wavelet, level=1, mode=self.mode) + return partial( + wavedec, wavelet=self.wavelet, level=1, mode=self.mode, axis=self.axis + ) def _get_waverec( self, @@ -179,11 +186,11 @@ def _get_waverec( if self.mode == "boundary": if length not in self._matrix_waverec_dict.keys(): self._matrix_waverec_dict[length] = MatrixWaverec( - self.wavelet, boundary=self.boundary + self.wavelet, boundary=self.boundary, axis=self.axis ) return self._matrix_waverec_dict[length] else: - return partial(waverec, wavelet=self.wavelet) + return partial(waverec, wavelet=self.wavelet, axis=self.axis) def get_level(self, level: int) -> List[str]: """Return the graycode-ordered paths to the filter tree nodes. @@ -257,15 +264,16 @@ def __init__( data: Optional[torch.Tensor], wavelet: Union[Wavelet, str], mode: str = "reflect", + maxlevel: Optional[int] = None, + axes: Tuple[int, int] = (-2, -1), boundary_orthogonalization: str = "qr", separable: bool = False, - maxlevel: Optional[int] = None, ) -> None: """Create a 2D-Wavelet packet tree. Args: - data (torch.tensor, optional): The input data tensor - of shape ``[batch_size, height, width]`` or + data (torch.tensor, optional): The input data tensor. + For example of shape ``[batch_size, height, width]`` or ``[batch_size, channels, height, width]``. If None, the object is initialized without performing a decomposition. @@ -274,15 +282,18 @@ def __init__( mode (str): A string indicating the desired padding mode. If you select 'boundary', the sparse matrix backend is used. Defaults to 'reflect' + maxlevel (int, optional): Value is passed on to `transform`. + The highest decomposition level to compute. If None, the maximum level + is determined from the input data shape. Defaults to None. + axes ([int, int], optional): The tensor axes that should be transformed. + Defaults to (-2, -1). boundary_orthogonalization (str): The orthogonalization method to use in the sparse matrix backend. Only used if `mode` equals 'boundary'. Choose from 'qr' or 'gramschmidt'. Defaults to 'qr'. separable (bool): If true, a separable transform is performed, i.e. each image axis is transformed separately. Defaults to False. - maxlevel (int, optional): Value is passed on to `transform`. - The highest decomposition level to compute. If None, the maximum level - is determined from the input data shape. Defaults to None. + """ self.wavelet = _as_wavelet(wavelet) self.mode = mode @@ -290,6 +301,7 @@ def __init__( self.separable = separable self.matrix_wavedec2_dict: Dict[Tuple[int, ...], MatrixWavedec2] = {} self.matrix_waverec2_dict: Dict[Tuple[int, ...], MatrixWaverec2] = {} + self.axes = axes self.maxlevel: Optional[int] = None if data is not None: @@ -382,6 +394,7 @@ def _get_wavedec( self.matrix_wavedec2_dict[shape] = MatrixWavedec2( self.wavelet, level=1, + axes=self.axes, boundary=self.boundary, separable=self.separable, ) @@ -389,10 +402,18 @@ def _get_wavedec( return fun elif self.separable: return self._transform_fsdict_to_tuple_func( - partial(fswavedec2, wavelet=self.wavelet, level=1, mode=self.mode) + partial( + fswavedec2, + wavelet=self.wavelet, + level=1, + mode=self.mode, + axes=self.axes, + ) ) else: - return partial(wavedec2, wavelet=self.wavelet, level=1, mode=self.mode) + return partial( + wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes + ) def _get_waverec( self, shape: Tuple[int, ...] @@ -405,16 +426,17 @@ def _get_waverec( if shape not in self.matrix_waverec2_dict.keys(): self.matrix_waverec2_dict[shape] = MatrixWaverec2( self.wavelet, + axes=self.axes, boundary=self.boundary, separable=self.separable, ) return self.matrix_waverec2_dict[shape] elif self.separable: return self._transform_tuple_to_fsdict_func( - partial(fswaverec2, wavelet=self.wavelet) + partial(fswaverec2, wavelet=self.wavelet, axes=self.axes) ) else: - return partial(waverec2, wavelet=self.wavelet) + return partial(waverec2, wavelet=self.wavelet, axes=self.axes) def _transform_fsdict_to_tuple_func( self, diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 1495ff27..93dd7872 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -4,14 +4,26 @@ individually using torch.nn.functional.conv1d and it's transpose. """ -from typing import Dict, List, Optional, Union +from functools import partial +from typing import Dict, List, Optional, Tuple, Union import numpy as np import pywt import torch -from ._util import _as_wavelet, _fold_channels, _unfold_channels +from ._util import ( + _as_wavelet, + _check_axes_argument, + _check_if_tensor, + _fold_axes, + _is_dtype_supported, + _map_result, + _swap_axes, + _undo_swap_axes, + _unfold_axes, +) from .conv_transform import wavedec, waverec +from .conv_transform_2 import _preprocess_tensor_dec2d def _separable_conv_dwtn_( @@ -94,30 +106,20 @@ def _separable_conv_idwtn( def _separable_conv_wavedecn( input: torch.Tensor, - wavelet: Union[str, pywt.Wavelet], + wavelet: pywt.Wavelet, mode: str = "reflect", level: Optional[int] = None, ) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a multilevel separable padded wavelet analysis transform. Args: - input (torch.Tensor): A tensor of shape [batch, axis_1, ... axis_n]. - Everything but the batch axis will be transformed. - wavelet (Wavelet or str): A pywt wavelet compatible object or - the name of a pywt wavelet. - Please consider the output from ``pywt.wavelist(kind='discrete')`` - for possible choices. - mode (str): The desired padding mode. Padding extends the signal along - the edges. Supported methods are:: - - "reflect", "zero", "constant", "periodic". - - Defaults to "reflect". - level (int): The desired decomposition level. If None the - largest possible decomposition value is used. + input (torch.Tensor): A tensor i.e. of shape [batch,axis_1, ... axis_n]. + wavelet (Wavelet): A pywt wavelet compatible object. + mode (str): The desired padding mode. + level (int): The desired decomposition level. Returns: - List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: _description_ + List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: The wavelet coeffs. """ result: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] approx = input @@ -139,30 +141,30 @@ def _separable_conv_wavedecn( def _separable_conv_waverecn( - coeff_list: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], - wavelet: Union[str, pywt.Wavelet], + coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], + wavelet: pywt.Wavelet, ) -> torch.Tensor: """Separable n-dimensional wavelet synthesis transform. Args: - coeff_list (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coeffs (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): The output as produced by `_separable_conv_wavedecn`. - wavelet (Union[str, pywt.Wavelet]): + wavelet (pywt.Wavelet): The wavelet used by `_separable_conv_wavedecn`. Returns: torch.Tensor: The reconstruction of the original signal. Raises: - ValueError: If the coeff_list is not structured as expected. + ValueError: If the coeffs is not structured as expected. """ - if not isinstance(coeff_list[0], torch.Tensor): + if not isinstance(coeffs[0], torch.Tensor): raise ValueError("approximation tensor must be first in coefficient list.") - if not all(map(lambda x: isinstance(x, dict), coeff_list[1:])): + if not all(map(lambda x: isinstance(x, dict), coeffs[1:])): raise ValueError("All entries after approximation tensor must be dicts.") - approx: torch.Tensor = coeff_list[0] - for level_dict in coeff_list[1:]: + approx: torch.Tensor = coeffs[0] + for level_dict in coeffs[1:]: keys = list(level_dict.keys()) # type: ignore level_dict["a" * max(map(len, keys))] = approx # type: ignore approx = _separable_conv_idwtn(level_dict, wavelet) # type: ignore @@ -170,15 +172,16 @@ def _separable_conv_waverecn( def fswavedec2( - input: torch.Tensor, + data: torch.Tensor, wavelet: Union[str, pywt.Wavelet], mode: str = "reflect", level: Optional[int] = None, + axes: Tuple[int, int] = (-2, -1), ) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a fully separable 2D-padded analysis wavelet transform. Args: - input (torch.Tensor): An input signal of shape ``[batch, height, width]`` + data (torch.Tensor): An data signal of shape ``[batch, height, width]`` or ``[batch, channels, height, width]``. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of @@ -190,9 +193,11 @@ def fswavedec2( This function defaults to "reflect". level (int): The number of desired scales. Defaults to None. + axes ([int, int]): The axes we want to transform, + defaults to (-2, -1). Raises: - ValueError: If the input is not a batched 2D signal. + ValueError: If the data is not a batched 2D signal. Returns: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: @@ -212,42 +217,42 @@ def fswavedec2( >>> coeff = ptwt.fswavedec2(data, "haar", level=2) """ - fold = False - if len(input.shape) == 2: - input = input.unsqueeze(0) - elif input.dim() == 4: - # fold channels into batches. - fold = True - ds = list(input.shape) - input = _fold_channels(input) - elif len(input.shape) != 3: - raise ValueError("Batched 2d inputs required for a 2d transform.") - res = _separable_conv_wavedecn(input, wavelet, mode, level) - - if fold: - unfold: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] - for resel in res: - if isinstance(resel, torch.Tensor): - unfold.append(_unfold_channels(resel, ds)) - else: - unfold.append( - {key: _unfold_channels(value, ds) for key, value in resel.items()} - ) - res = unfold + if not _is_dtype_supported(data.dtype): + raise ValueError(f"Input dtype {data.dtype} not supported") + + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + data = _swap_axes(data, list(axes)) + + wavelet = _as_wavelet(wavelet) + data, ds = _preprocess_tensor_dec2d(data) + data = data.squeeze(1) + res = _separable_conv_wavedecn(data, wavelet, mode, level) + + if ds: + _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) + res = _map_result(res, _unfold_axes2) + + if axes != (-2, -1): + undo_swap_fn = partial(_undo_swap_axes, axes=axes) + res = _map_result(res, undo_swap_fn) return res def fswavedec3( - input: torch.Tensor, + data: torch.Tensor, wavelet: Union[str, pywt.Wavelet], mode: str = "reflect", level: Optional[int] = None, + axes: Tuple[int, int, int] = (-3, -2, -1), ) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: """Compute a fully separable 3D-padded analysis wavelet transform. Args: - input (torch.Tensor): An input signal of shape [batch, depth, height, width]. + data (torch.Tensor): An input signal of shape [batch, depth, height, width]. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of ``pywt.wavelist(kind="discrete")`` for a list of possible choices. @@ -258,6 +263,8 @@ def fswavedec3( This function defaults to "reflect". level (int): The number of desired scales. Defaults to None. + axes (Tuple[int, int, int]): Compute the transform over these axes + instead of the last three. Defaults to (-3, -2, -1). Raises: ValueError: If the input is not a batched 3D signal. @@ -278,30 +285,57 @@ def fswavedec3( >>> data = torch.randn(5, 10, 10, 10) >>> coeff = ptwt.fswavedec3(data, "haar", level=2) """ - if len(input.shape) == 3: - input = input.unsqueeze(0) - if len(input.shape) != 4: - raise ValueError("Batched 3d inputs required for a 3d transform.") + if not _is_dtype_supported(data.dtype): + raise ValueError(f"Input dtype {data.dtype} not supported") + + if tuple(axes) != (-3, -2, -1): + if len(axes) != 3: + raise ValueError("2D transforms work with two axes.") + else: + data = _swap_axes(data, list(axes)) - return _separable_conv_wavedecn(input, wavelet, mode, level) + wavelet = _as_wavelet(wavelet) + ds = None + if len(data.shape) >= 5: + data, ds = _fold_axes(data, 3) + elif len(data.shape) < 4: + raise ValueError("At lest four input dimensions are required.") + data = data.squeeze(1) + res = _separable_conv_wavedecn(data, wavelet, mode, level) + + if ds: + _unfold_axes3 = partial(_unfold_axes, ds=ds, keep_no=3) + res = _map_result(res, _unfold_axes3) + + if axes != (-3, -2, -1): + undo_swap_fn = partial(_undo_swap_axes, axes=axes) + res = _map_result(res, undo_swap_fn) + + return res def fswaverec2( - coeff_list: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: Union[str, pywt.Wavelet], + axes: Tuple[int, int] = (-2, -1), ) -> torch.Tensor: """Compute a fully separable 2D-padded synthesis wavelet transform. Args: - coeff_list (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coeffs (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): The wavelet coefficients as computed by `fswavedec2`. wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. + axes (Tuple[int, int]): Compute the transform over these + axes instead of the last two. Defaults to (-2, -1). Returns: torch.Tensor: A reconstruction of the signal encoded in the wavelet coefficients. + Raises: + ValueError: If the axes argument is not a tuple of two integers. + Example: >>> import torch >>> import ptwt @@ -310,25 +344,63 @@ def fswaverec2( >>> rec = ptwt.fswaverec2(coeff, "haar") """ - return _separable_conv_waverecn(coeff_list, wavelet) + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + _check_axes_argument(list(axes)) + swap_fn = partial(_swap_axes, axes=list(axes)) + coeffs = _map_result(coeffs, swap_fn) + + ds = None + wavelet = _as_wavelet(wavelet) + + res_ll = _check_if_tensor(coeffs[0]) + torch_dtype = res_ll.dtype + + if res_ll.dim() >= 4: + # avoid the channel sum, fold the channels into batches. + ds = _check_if_tensor(coeffs[0]).shape + coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]) + res_ll = _check_if_tensor(coeffs[0]) + + if not _is_dtype_supported(torch_dtype): + raise ValueError(f"Input dtype {torch_dtype} not supported") + + res_ll = _separable_conv_waverecn(coeffs, wavelet) + + if ds: + res_ll = _unfold_axes(res_ll, list(ds), 2) + + if axes != (-2, -1): + res_ll = _undo_swap_axes(res_ll, list(axes)) + + return res_ll def fswaverec3( - coeff_list: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], + coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], wavelet: Union[str, pywt.Wavelet], + axes: Tuple[int, int, int] = (-3, -2, -1), ) -> torch.Tensor: """Compute a fully separable 3D-padded synthesis wavelet transform. Args: - coeff_list (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): + coeffs (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]): The wavelet coefficients as computed by `fswavedec3`. wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the synthesis transform. + axes (Tuple[int, int, int]): Compute the transform over these axes + instead of the last three. Defaults to (-3, -2, -1). Returns: torch.Tensor: A reconstruction of the signal encoded in the wavelet coefficients. + Raises: + ValueError: If the axes argument is not a tuple with + three ints. + Example: >>> import torch >>> import ptwt @@ -337,4 +409,34 @@ def fswaverec3( >>> rec = ptwt.fswaverec3(coeff, "haar") """ - return _separable_conv_waverecn(coeff_list, wavelet) + if tuple(axes) != (-3, -2, -1): + if len(axes) != 3: + raise ValueError("2D transforms work with two axes.") + else: + _check_axes_argument(list(axes)) + swap_fn = partial(_swap_axes, axes=list(axes)) + coeffs = _map_result(coeffs, swap_fn) + + ds = None + wavelet = _as_wavelet(wavelet) + res_ll = _check_if_tensor(coeffs[0]) + torch_dtype = res_ll.dtype + + if res_ll.dim() >= 5: + # avoid the channel sum, fold the channels into batches. + ds = _check_if_tensor(coeffs[0]).shape + coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 3)[0]) + res_ll = _check_if_tensor(coeffs[0]) + + if not _is_dtype_supported(torch_dtype): + raise ValueError(f"Input dtype {torch_dtype} not supported") + + res_ll = _separable_conv_waverecn(coeffs, wavelet) + + if ds: + res_ll = _unfold_axes(res_ll, list(ds), 3) + + if axes != (-3, -2, -1): + res_ll = _undo_swap_axes(res_ll, list(axes)) + + return res_ll diff --git a/tests/test_convolution_fwt.py b/tests/test_convolution_fwt.py index c6787c6f..d7663c28 100644 --- a/tests/test_convolution_fwt.py +++ b/tests/test_convolution_fwt.py @@ -18,6 +18,7 @@ from tests._mackey_glass import MackeyGenerator +@pytest.mark.slow @pytest.mark.parametrize("wavelet_string", ["db1", "db2", "db3", "db4", "db5", "sym5"]) @pytest.mark.parametrize("level", [1, 2, None]) @pytest.mark.parametrize("length", [64, 65]) @@ -110,6 +111,38 @@ def test_orth_wavelet(): assert np.allclose(res.detach().numpy(), mackey_data_1.numpy()) +@pytest.mark.parametrize("level", [1, 2, 3, None]) +@pytest.mark.parametrize("shape", [(64,), (1, 64), (3, 2, 64), (4, 3, 2, 64)]) +def test_1d_multibatch(level, shape): + """Test 1D conv support for multiple inert batch dimensions.""" + data = torch.randn(*shape, dtype=torch.float64) + ptwt_coeff = wavedec(data, "haar", level=level) + pywt_coeff = pywt.wavedec(data, "haar", level=level, mode="reflect") + + # test coefficients + test_list = _compare_coeffs(ptwt_coeff, pywt_coeff) + assert all(test_list) + + # test reconstruction + rec = waverec(ptwt_coeff, "haar") + assert torch.allclose(rec, data) + + +@pytest.mark.parametrize("axis", [-1, 0, 1, 2]) +def test_1d_axis_arg(axis): + """Ensure the axis argument works as expected.""" + data = torch.randn([16, 16, 16], dtype=torch.float64) + + ptwtcs = wavedec(data, "haar", level=2, axis=axis) + pywtcs = pywt.wavedec(data, "haar", level=2, axis=axis) + + test_list = _compare_coeffs(ptwtcs, pywtcs) + assert all(test_list) + + rec = waverec(ptwtcs, "haar", axis=axis) + assert torch.allclose(rec, data) + + def test_2d_haar_lvl1(): """Test a 2d-Haar wavelet conv-fwt.""" # ------------------------- 2d haar wavelet tests ----------------------- @@ -211,8 +244,8 @@ def test_2d_wavedec_rec(wavelet_str, level, size, mode): @pytest.mark.parametrize( "size", [(50, 20, 128, 128), (49, 21, 128, 128), (4, 5, 64, 64)] ) -@pytest.mark.parametrize("level", [1, 3, None]) -@pytest.mark.parametrize("wavelet", ["haar", "db2", "sym3"]) +@pytest.mark.parametrize("level", [1, None]) +@pytest.mark.parametrize("wavelet", ["haar", "sym3"]) def test_input_4d(size, level, wavelet): """Test the error for 4d inputs to wavedec2.""" data = torch.randn(*size).type(torch.float64) @@ -251,3 +284,85 @@ def test_input_1d_dimension_error(): with pytest.raises(ValueError): data = torch.randn(50) wavedec2(data, "haar", 4) + + +def _compare_coeffs(ptwt_res, pywt_res): + """Compare coefficient lists. + + Args: + ptwt_res: Our result list. + pywt_res: A pyt result list. + + Returns: + A list with bools from allclose. + """ + test_list = [] + for ptwtcs, pywtcs in zip(ptwt_res, pywt_res): + if isinstance(ptwtcs, tuple): + test_list.extend( + tuple( + np.allclose(ptwtc.numpy(), pywtc) + for ptwtc, pywtc in zip(ptwtcs, pywtcs) + ) + ) + else: + test_list.append(np.allclose(ptwtcs.numpy(), pywtcs)) + return test_list + + +@pytest.mark.slow +@pytest.mark.parametrize( + "size", [(50, 20, 128, 128), (8, 49, 21, 128, 128), (6, 4, 4, 5, 64, 64)] +) +def test_2d_multidim_input(size): + """Test the error for multi-dimensional inputs to wavedec2.""" + data = torch.randn(*size, dtype=torch.float64) + wavelet = "db2" + level = 3 + + pt_res = wavedec2(data, wavelet=wavelet, level=level, mode="reflect") + pywt_res = pywt.wavedec2(data.numpy(), wavelet=wavelet, level=level, mode="reflect") + rec = waverec2(pt_res, wavelet) + + # test coefficients + test_list = _compare_coeffs(pt_res, pywt_res) + assert all(test_list) + + # test reconstruction. + assert np.allclose( + data.numpy(), rec.numpy()[..., : data.shape[-2], : data.shape[-1]] + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (-3, -2), (0, 1), (1, 0)]) +def test_2d_axis_argument(axes): + """Ensure the axes argument works as expected.""" + data = torch.randn([32, 32, 32, 32], dtype=torch.float64) + + ptwt_coeff = wavedec2(data, "db2", level=3, mode="reflect", axes=axes) + pywt_coeff = pywt.wavedec2(data, "db2", level=3, mode="reflect", axes=axes) + rec = waverec2(ptwt_coeff, "db2", axes=axes) + + # test coefficients + test_list = _compare_coeffs(ptwt_coeff, pywt_coeff) + assert all(test_list) + + # test reconstruction. + assert np.allclose( + data.numpy(), rec.numpy()[..., : data.shape[-2], : data.shape[-1]] + ) + + +def test_2d_axis_error_axes_count(): + """Check the error for too many axes.""" + with pytest.raises(ValueError): + data = torch.randn([32, 32, 32, 32], dtype=torch.float64) + wavedec2(data, "haar", 1, axes=(1, 2, 3)) + + +def test_2d_axis_error_axes_repetition(): + """Check the error for axes repetition.""" + with pytest.raises(ValueError): + data = torch.randn([32, 32, 32, 32], dtype=torch.float64) + wavedec2(data, "haar", 1, axes=(2, 2)) diff --git a/tests/test_convolution_fwt_3.py b/tests/test_convolution_fwt_3.py index fead2d84..13b1ebbd 100644 --- a/tests/test_convolution_fwt_3.py +++ b/tests/test_convolution_fwt_3.py @@ -83,3 +83,76 @@ def test_waverec3(shape: list, wavelet: str, level: int, mode: str) -> None: assert np.allclose( rec.numpy()[..., : shape[1], : shape[2], : shape[3]], data.numpy() ) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "size", [[5, 32, 32, 32], [4, 3, 32, 32, 32], [1, 1, 1, 32, 32, 32]] +) +@pytest.mark.parametrize("level", [1, 2, None]) +@pytest.mark.parametrize("wavelet", ["haar", "sym3", "db3"]) +@pytest.mark.parametrize("mode", ["zero", "symmetric", "reflect"]) +def test_multidim_input(size: List[int], level: int, wavelet: str, mode: str): + """Ensure correct folding of multidimensional inputs.""" + data = torch.randn(size, dtype=torch.float64) + ptwc = ptwt.wavedec3(data, wavelet, level=level, mode=mode) + # batch_list = [] + # for batch_no in range(data.shape[0]): + # pywc = pywt.wavedecn(data[batch_no].numpy(), wavelet, level=level, mode=mode) + # batch_list.append(pywc) + # cat_pywc = _cat_batch_list(batch_list) + cat_pywc = pywt.wavedecn(data, wavelet, level=level, mode=mode, axes=[-3, -2, -1]) + + # ensure ptwt and pywt coefficients are identical. + test_list = [] + for a, b in zip(ptwc, cat_pywc): + if type(a) is torch.Tensor: + test_list.append(np.allclose(a, b)) + else: + test_list.extend([np.allclose(a[key], b[key]) for key in a.keys()]) + + assert all(test_list) + + rec = ptwt.waverec3(ptwc, wavelet) + assert np.allclose( + rec.numpy()[..., : size[-3], : size[-2], : size[-1]], data.numpy() + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("axes", [[-3, -2, -1], [0, 2, 1]]) +@pytest.mark.parametrize("level", [1, 2, None]) +@pytest.mark.parametrize("mode", ["zero", "symmetric", "reflect"]) +def test_axes_arg_3d(axes: List[int], level: int, mode: str) -> None: + """Test axes argument support.""" + wavelet = "db3" + data = torch.randn([16, 16, 16, 16, 16], dtype=torch.float64) + ptwc = ptwt.wavedec3(data, wavelet, level=level, mode=mode, axes=axes) + cat_pywc = pywt.wavedecn(data, wavelet, level=level, mode=mode, axes=axes) + + # ensure ptwt and pywt coefficients are identical. + test_list = [] + for a, b in zip(ptwc, cat_pywc): + if type(a) is torch.Tensor: + test_list.append(np.allclose(a, b)) + else: + test_list.extend([np.allclose(a[key], b[key]) for key in a.keys()]) + + assert all(test_list) + + rec = ptwt.waverec3(ptwc, wavelet, axes=axes) + assert np.allclose(data, rec) + + +def test_2d_dimerror(): + """Check the error for too many axes.""" + with pytest.raises(ValueError): + data = torch.randn([32, 32], dtype=torch.float64) + ptwt.wavedec3(data, "haar") + + +def test_1d_dimerror(): + """Check the error for too many axes.""" + with pytest.raises(ValueError): + data = torch.randn([32], dtype=torch.float64) + ptwt.wavedec3(data, "haar") diff --git a/tests/test_matrix_fwt.py b/tests/test_matrix_fwt.py index 66116785..0cc82484 100644 --- a/tests/test_matrix_fwt.py +++ b/tests/test_matrix_fwt.py @@ -183,11 +183,11 @@ def test_matrix_transform_1d_rebuild(wavelet_str: str, boundary: str): ) -def test_4d_input_to_1d_transform_dimension_error(): - """Test the error for 1d inputs to the MatrixWavedec __call__.""" +def test_4d_invalid_axis_error(): + """Test the error for 1d axis arguments.""" with pytest.raises(ValueError): data = torch.randn(50, 50, 50, 50) - matrix_wavedec_1d = MatrixWavedec("haar", 4) + matrix_wavedec_1d = MatrixWavedec("haar", axis=(1, 2)) matrix_wavedec_1d(data) @@ -209,3 +209,21 @@ def test_matrix1d_batch_channel(size): rec = matrix_waverec_2d(ptwt_coeff) assert np.allclose(data.numpy(), rec.numpy()) + + +@pytest.mark.parametrize("axis", (0, 1, 2, 3, 4)) +def test_axis_1d(axis): + """Ensure the axis argument is supported correctly.""" + data = torch.randn(24, 24, 24, 24, 24).type(torch.float64) + matrix_wavedec = MatrixWavedec(wavelet="haar", level=3, axis=axis) + coeff = matrix_wavedec(data) + coeff_pywt = pywt.wavedec(data.numpy(), wavelet="haar", level=3, axis=axis) + assert len(coeff) == len(coeff_pywt) + assert all( + [np.allclose(coeff, coeff_pywt) for coeff, coeff_pywt in zip(coeff, coeff_pywt)] + ) + + matrix_waverec = MatrixWaverec("haar", axis=axis) + + rec = matrix_waverec(coeff) + assert np.allclose(rec, data) diff --git a/tests/test_matrix_fwt_2.py b/tests/test_matrix_fwt_2.py index cb413d2a..d975a32e 100644 --- a/tests/test_matrix_fwt_2.py +++ b/tests/test_matrix_fwt_2.py @@ -14,6 +14,7 @@ construct_boundary_a2, construct_boundary_s2, ) +from tests.test_convolution_fwt import _compare_coeffs @pytest.mark.parametrize("size", [(16, 16), (16, 8), (8, 16)]) @@ -226,3 +227,21 @@ def test_empty_inverse_operators(operator) -> None: matrixifwt = operator("haar") with pytest.raises(ValueError): _ = matrixifwt.sparse_ifwt_operator + + +@pytest.mark.slow +@pytest.mark.parametrize("axes", ((-2, -1), (-1, -2), (-3, -2), (0, 1), (1, 0))) +def test_axes_2d(axes): + """Ensure the axes argument is supported correctly.""" + # TODO: write me. + data = torch.randn(24, 24, 24, 24, 24).type(torch.float64) + matrix_wavedec2 = MatrixWavedec2(wavelet="haar", level=3, axes=axes) + coeff = matrix_wavedec2(data) + coeff_pywt = pywt.wavedec2(data.numpy(), wavelet="haar", level=3, axes=axes) + assert len(coeff) == len(coeff_pywt) + assert _compare_coeffs(coeff, coeff_pywt) + + matrix_waverec2 = MatrixWaverec2("haar", axes=axes) + + rec = matrix_waverec2(coeff) + assert np.allclose(rec, data) diff --git a/tests/test_matrix_fwt_3.py b/tests/test_matrix_fwt_3.py index c49cf86d..09790810 100644 --- a/tests/test_matrix_fwt_3.py +++ b/tests/test_matrix_fwt_3.py @@ -1,4 +1,5 @@ """Test the 3d matrix-fwt code.""" +from typing import List import numpy as np import pytest @@ -82,3 +83,29 @@ def test_boundary_wavedec3_inverse(level, shape): assert np.allclose( test_data.numpy(), rec[:, : shape[0], : shape[1], : shape[2]].numpy() ) + + +@pytest.mark.slow +@pytest.mark.parametrize("axes", [[-3, -2, -1], [0, 2, 1]]) +@pytest.mark.parametrize("level", [1, 2, None]) +def test_axes_arg_matrix_3d(axes: List[int], level: int) -> None: + """Test axes 3d matmul argument support.""" + wavelet = "haar" + data = torch.randn([16, 16, 16, 16, 16], dtype=torch.float64) + ptwc = MatrixWavedec3(wavelet, level=level, axes=axes)(data) + pywc = pywt.wavedecn(data, wavelet, level=level, axes=axes) + + # ensure ptwt and pywt coefficients are identical. + test_list = [] + for a, b in zip(ptwc, pywc): + if type(a) is torch.Tensor: + test_list.append(np.allclose(a, b)) + else: + for key in a.keys(): + test_list.append(np.allclose(b[key], a[key].numpy())) + + assert all(test_list) + + # test inversion + rec = MatrixWaverec3(wavelet, axes=axes)(ptwc) + assert np.allclose(data, rec.numpy()) diff --git a/tests/test_packets.py b/tests/test_packets.py index 0e5d698f..488b8035 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -289,38 +289,46 @@ def test_access_errors_2d(): twp["a" * 100] -@pytest.mark.parametrize("level", [1, 2, 3]) +@pytest.mark.slow +@pytest.mark.parametrize("level", [1, 3]) @pytest.mark.parametrize("base_key", ["a", "d"]) -@pytest.mark.parametrize("shape", [[1, 63], [3, 2, 64], [128]]) +@pytest.mark.parametrize("shape", [[1, 64, 63], [3, 64, 64], [1, 128]]) @pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"]) -def test_inverse_packet_1d(level, base_key, shape, wavelet): +@pytest.mark.parametrize("axis", (1, -1)) +def test_inverse_packet_1d(level, base_key, shape, wavelet, axis): """Test the 1d reconstruction code.""" signal = np.random.randn(*shape) mode = "reflect" - wp = pywt.WaveletPacket(signal, wavelet, mode=mode, maxlevel=level) - ptwp = WaveletPacket(torch.from_numpy(signal), wavelet, mode=mode, maxlevel=level) + wp = pywt.WaveletPacket(signal, wavelet, mode=mode, maxlevel=level, axis=axis) + ptwp = WaveletPacket( + torch.from_numpy(signal), wavelet, mode=mode, maxlevel=level, axis=axis + ) wp[base_key * level].data *= 0 ptwp[base_key * level].data *= 0 wp.reconstruct(update=True) ptwp.reconstruct() - assert np.allclose(wp[""].data, ptwp[""].numpy()[..., : shape[-1]]) + assert np.allclose(wp[""].data, ptwp[""].numpy()[..., : shape[-2], : shape[-1]]) +@pytest.mark.slow @pytest.mark.parametrize("level", [1, 3]) @pytest.mark.parametrize("base_key", ["a", "h", "d"]) -@pytest.mark.parametrize("size", [(1, 32, 32), (2, 1, 31, 64)]) +@pytest.mark.parametrize("size", [(32, 32, 32), (32, 32, 31, 64)]) @pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"]) -def test_inverse_packet_2d(level, base_key, size, wavelet): +@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) +def test_inverse_packet_2d(level, base_key, size, wavelet, axes): """Test the 2d reconstruction code.""" signal = np.random.randn(*size) mode = "reflect" - wp = pywt.WaveletPacket2D(signal, wavelet, mode=mode, maxlevel=level) - ptwp = WaveletPacket2D(torch.from_numpy(signal), wavelet, mode=mode, maxlevel=level) + wp = pywt.WaveletPacket2D(signal, wavelet, mode=mode, maxlevel=level, axes=axes) + ptwp = WaveletPacket2D( + torch.from_numpy(signal), wavelet, mode=mode, maxlevel=level, axes=axes + ) wp[base_key * level].data *= 0 ptwp[base_key * level].data *= 0 wp.reconstruct(update=True) ptwp.reconstruct() - assert np.allclose(wp[""].data, ptwp[""].numpy()[:, : size[1], : size[2]]) + assert np.allclose(wp[""].data, ptwp[""].numpy()[: size[0], : size[1], : size[2]]) def test_inverse_boundary_packet_1d(): @@ -351,3 +359,21 @@ def test_inverse_boundary_packet_2d(): wp.reconstruct(update=True) ptwp.reconstruct() assert np.allclose(wp[""].data, ptwp[""].numpy()[:, : size[0], : size[1]]) + + +@pytest.mark.slow +@pytest.mark.parametrize("axes", ((-2, -1), (1, 2), (2, 1))) +def test_separable_conv_packets_2d(axes): + """Ensure the 2d separable conv code is ok.""" + wavelet = "db2" + signal = np.random.randn(1, 32, 32, 32) + ptwp = WaveletPacket2D( + torch.from_numpy(signal), + wavelet, + mode="reflect", + maxlevel=2, + axes=axes, + separable=True, + ) + ptwp.reconstruct() + assert np.allclose(signal, ptwp[""].data[:, :32, :32, :32]) diff --git a/tests/test_separable_conv_fwt.py b/tests/test_separable_conv_fwt.py index 10d185e0..07afac4b 100644 --- a/tests/test_separable_conv_fwt.py +++ b/tests/test_separable_conv_fwt.py @@ -83,12 +83,17 @@ def test_example_fs3d(shape, wavelet): # test separable conv and mamul consistency for the Haar case. @pytest.mark.slow @pytest.mark.parametrize("level", [1, 2, 3, None]) -@pytest.mark.parametrize("shape", [[5, 128, 128], [3, 2, 64, 64], [1, 1, 64, 64]]) -def test_conv_mm_2d(level, shape): +@pytest.mark.parametrize( + "shape", [[1, 64, 128, 128], [1, 3, 64, 64, 64], [2, 1, 64, 64, 64]] +) +@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (2, 3), (3, 2)]) +def test_conv_mm_2d(level, shape, axes): """Compare mm and conv fully separable results.""" data = torch.randn(*shape).type(torch.float64) - fs_conv_coeff = fswavedec2(data, "haar", level=level) - fs_mm_coeff = MatrixWavedec2("haar", level, separable=True)(data) + fs_conv_coeff = fswavedec2(data, "haar", level=level, axes=axes) + fs_mm_coeff = MatrixWavedec2( + wavelet="haar", level=level, separable=True, axes=axes + )(data) # compare coefficients assert len(fs_conv_coeff) == len(fs_mm_coeff) for c_conv, c_mm in zip(fs_conv_coeff, fs_mm_coeff): @@ -101,18 +106,19 @@ def test_conv_mm_2d(level, shape): np.allclose(c_el_conv.numpy(), c_el_mm.numpy()) for c_el_conv, c_el_mm in zip(c_conv_list, c_mm) ) - rec = fswaverec2(fs_conv_coeff, "haar") + rec = fswaverec2(fs_conv_coeff, "haar", axes=axes) assert np.allclose(data.numpy(), rec.numpy()) @pytest.mark.slow @pytest.mark.parametrize("level", [1, 2, 3, None]) -def test_conv_mm_3d(level): +@pytest.mark.parametrize("axes", [(-3, -2, -1), (-1, -2, -3), (2, 3, 1)]) +@pytest.mark.parametrize("shape", [(5, 64, 128, 256)]) +def test_conv_mm_3d(level, axes, shape): """Compare mm and conv 3d fully separable results.""" - shape = (5, 128, 128, 128) data = torch.randn(*shape).type(torch.float64) - fs_conv_coeff = fswavedec3(data, "haar", level=level) - fs_mm_coeff = MatrixWavedec3("haar", level)(data) + fs_conv_coeff = fswavedec3(data, "haar", level=level, axes=axes) + fs_mm_coeff = MatrixWavedec3("haar", level, axes=axes)(data) # compare coefficients assert len(fs_conv_coeff) == len(fs_mm_coeff) for c_conv, c_mm in zip(fs_conv_coeff, fs_mm_coeff): @@ -121,5 +127,5 @@ def test_conv_mm_3d(level): else: keys = c_conv.keys() assert all(np.allclose(c_conv[key], c_mm[key]) for key in keys) - rec = fswaverec3(fs_conv_coeff, "haar") + rec = fswaverec3(fs_conv_coeff, "haar", axes=axes) assert np.allclose(data.numpy(), rec.numpy()) diff --git a/tests/test_util.py b/tests/test_util.py index ea579164..a6b6016c 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -8,10 +8,10 @@ from src.ptwt._util import ( _as_wavelet, - _fold_channels, + _fold_axes, _pad_symmetric, _pad_symmetric_1d, - _unfold_channels, + _unfold_axes, ) @@ -70,11 +70,12 @@ def test_pad_symmetric(size, pad_list): assert np.allclose(my_pad.numpy(), np_pad) +@pytest.mark.parametrize("keep_no", [1, 2, 3]) @pytest.mark.parametrize("size", [[20, 21, 22, 23], [1, 2, 3, 4], [4, 3, 2, 1]]) -def test_fold(size): +def test_fold(keep_no, size): """Ensure channel folding works as expected.""" array = torch.randn(*size).type(torch.float64) - folded = _fold_channels(array) - assert tuple(folded.shape) == (size[0] * size[1], size[2], size[3]) - rec = _unfold_channels(folded, size) + folded, ds = _fold_axes(array, keep_no) + assert len(folded.shape) == keep_no + 1 + rec = _unfold_axes(folded, size, keep_no) np.allclose(array.numpy(), rec.numpy())