
Commit

add axes support for conv2d along with tests
loki-veera committed Sep 21, 2023
1 parent 8e1a932 commit 1c9778c
Showing 3 changed files with 163 additions and 38 deletions.
13 changes: 8 additions & 5 deletions src/ptwt/_util.py
@@ -134,11 +134,14 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int
The folded result array, and the shape of the original input.
"""
dshape = list(data.shape)
return torch.reshape(data, [np.prod(dshape[:-keep_no])] + dshape[-keep_no:]), dshape
return (
torch.reshape(data, [int(np.prod(dshape[:-keep_no]))] + dshape[-keep_no:]),
dshape,
)


def _unfold_axes(data: torch.Tensor, ds: List[int], keep_no: int) -> torch.Tensor:
"""Unfold i.e. [batch*channel, height, widht] into [batch, channel, height, width]."""
"""Unfold [batch*channel, height, widht] into [batch, channel, height, width]."""
return torch.reshape(data, ds[:-keep_no] + list(data.shape[-keep_no:]))


@@ -167,11 +170,11 @@ def _get_transpose_order(
def _swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor:
_check_axes_argument(axes)
front, back = _get_transpose_order(axes, list(data.shape))
return torch.transpose(data, front + back)
return torch.permute(data, front + back)


def _undo_swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor:
_check_axes_argument(axes)
front, back = _get_transpose_order(axes, list(data.shape))
restore_sorted = torch.argsort(torch.tensor(front + back))
return torch.transpose(data, restore_sorted)
restore_sorted = torch.argsort(torch.tensor(front + back)).tolist()
return torch.permute(data, restore_sorted)
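
Not part of the diff: a minimal, self-contained sketch of the permute/argsort round trip that the reworked _swap_axes and _undo_swap_axes rely on. torch.argsort of a permutation yields its inverse permutation, so a second torch.permute restores the original layout.

import torch

data = torch.randn(2, 3, 32, 64)            # e.g. [batch, channel, height, width]
order = [0, 2, 3, 1]                         # move the channel axis to the back

swapped = torch.permute(data, order)         # shape [2, 32, 64, 3]

# argsort of a permutation is its inverse permutation.
restore = torch.argsort(torch.tensor(order)).tolist()
restored = torch.permute(swapped, restore)   # shape [2, 3, 32, 64] again

assert torch.equal(restored, data)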
117 changes: 84 additions & 33 deletions src/ptwt/conv_transform_2.py
@@ -5,19 +5,24 @@
"""


from typing import Any, List, Optional, Tuple, Union
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union

import pywt
import torch

from ._util import (
Wavelet,
_as_wavelet,
_fold_channels,
_check_axes_argument,
_fold_axes,
_get_len,
_is_dtype_supported,
_outer,
_pad_symmetric,
_swap_axes,
_undo_swap_axes,
_unfold_axes,
_unfold_channels,
)
from .conv_transform import (
@@ -122,23 +127,42 @@ def _waverec2d_fold_channels_2d_list(
ds = list(_check_if_tensor(coeffs[0]).shape)
for coeff in coeffs:
if isinstance(coeff, torch.Tensor):
fold_coeffs.append(_fold_channels(coeff))
fold_coeffs.append(_fold_axes(coeff, 2)[0])
else:
fold_coeffs.append(
(
_fold_channels(coeff[0]),
_fold_channels(coeff[1]),
_fold_channels(coeff[2]),
_fold_axes(coeff[0], 2)[0],
_fold_axes(coeff[1], 2)[0],
_fold_axes(coeff[2], 2)[0],
)
)
return fold_coeffs, ds


def _preprocess_tensor_dec2d(
data: torch.Tensor,
) -> Tuple[torch.Tensor, Union[List[int], None]]:
# Preprocess multidimensional input.
ds = None
if len(data.shape) == 2:
data = data.unsqueeze(0).unsqueeze(0)
elif len(data.shape) == 3:
# add a channel dimension for torch.
data = data.unsqueeze(1)
elif len(data.shape) >= 4:
data, ds = _fold_axes(data, 2)
data = data.unsqueeze(1)
elif len(data.shape) == 1:
raise ValueError("More than one input dimension required.")
return data, ds
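
Not part of the diff: a rough, self-contained illustration of the shape handling above (the helper name _preprocess_sketch is hypothetical). 2d inputs gain batch and channel axes, 3d inputs gain a channel axis, and anything above the last two axes is folded into the batch dimension; ds records the original shape so the fold can be undone later and stays None otherwise.

import torch

def _preprocess_sketch(data: torch.Tensor):
    # Hypothetical sketch of the shape handling only.
    ds = None
    if data.dim() == 1:
        raise ValueError("More than one input dimension required.")
    if data.dim() == 2:                       # [h, w] -> [1, 1, h, w]
        data = data.unsqueeze(0).unsqueeze(0)
    elif data.dim() == 3:                     # [batch, h, w] -> [batch, 1, h, w]
        data = data.unsqueeze(1)
    else:                                     # fold extra axes into the batch
        ds = list(data.shape)
        batch = int(torch.tensor(ds[:-2]).prod())
        data = data.reshape([batch] + ds[-2:]).unsqueeze(1)
    return data, ds

print(_preprocess_sketch(torch.randn(32, 32))[0].shape)        # [1, 1, 32, 32]
print(_preprocess_sketch(torch.randn(5, 32, 32))[0].shape)     # [5, 1, 32, 32]
print(_preprocess_sketch(torch.randn(4, 3, 32, 32))[0].shape)  # [12, 1, 32, 32]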


def wavedec2(
data: torch.Tensor,
wavelet: Union[Wavelet, str],
mode: str = "reflect",
level: Optional[int] = None,
axes: Tuple[int, int] = (-2, -1),
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
"""Non-separated two-dimensional wavelet transform. Only the last two axes change.
@@ -157,6 +181,8 @@ def wavedec2(
This function defaults to "reflect".
level (int): The number of desired scales.
Defaults to None.
axes (Tuple[int, int]): Compute the transform over these axes instead of the
last two. Defaults to (-2, -1).
Returns:
list: A list containing the wavelet coefficients.
@@ -183,29 +209,17 @@ def wavedec2(
>>> level=2, mode="zero")
"""
fold = False
if data.dim() == 2:
data = data.unsqueeze(0).unsqueeze(0)
elif data.dim() == 3:
# add a channel dimension for torch.
data = data.unsqueeze(1)
elif data.dim() == 4:
# avoid the channel sum, fold the channels into batches.
fold = True
ds = data.shape
data = _fold_channels(data).unsqueeze(1)
elif data.dim() == 1:
raise ValueError("Wavedec2 needs more than one input dimension to work.")
else:
raise ValueError(
"Wavedec2 does not support more than four input dimensions. \
Optionally-batched two-dimensional inputs work."
)

if not _is_dtype_supported(data.dtype):
raise ValueError(f"Input dtype {data.dtype} not supported")

if tuple(axes) != (-2, -1):
if len(axes) != 2:
raise ValueError("2D transforms work with two axes.")
else:
data = _swap_axes(data, list(axes))

wavelet = _as_wavelet(wavelet)
data, ds = _preprocess_tensor_dec2d(data)
dec_lo, dec_hi, _, _ = _get_filter_tensors(
wavelet, flip=True, device=data.device, dtype=data.dtype
)
@@ -225,11 +239,35 @@
to_append = (res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1))
result_lst.append(to_append)
result_lst.append(res_ll.squeeze(1))
result_lst.reverse()

if ds:
_unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2)
result_lst = _map_result(result_lst, _unfold_axes2)

if axes != (-2, -1):
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
result_lst = _map_result(result_lst, undo_swap_fn)

if fold:
result_lst = _wavedec2d_unfold_channels_2d_list(result_lst, list(ds))
return result_lst

return result_lst[::-1]

def _map_result(
data: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
function: Callable[[Any], torch.Tensor],
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
# Apply the given function to the input list of tensors and tuples.
result_lst: List[
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
] = []
for element in data:
if isinstance(element, torch.Tensor):
result_lst.append(function(element))
else:
result_lst.append(
(function(element[0]), function(element[1]), function(element[2]))
)
return result_lst
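
Not part of the diff: a quick, self-contained illustration of how such a mapping walks the nested coefficient list, one approximation tensor followed by (horizontal, vertical, diagonal) tuples per level (map_result_sketch is a hypothetical stand-in).

import torch

def map_result_sketch(coeffs, fn):
    out = []
    for element in coeffs:
        if isinstance(element, torch.Tensor):
            out.append(fn(element))                    # approximation tensor
        else:
            out.append(tuple(fn(c) for c in element))  # detail tuple
    return out

coeffs = [
    torch.randn(1, 8, 8),                                                       # cA
    (torch.randn(1, 8, 8), torch.randn(1, 8, 8), torch.randn(1, 8, 8)),         # level 2
    (torch.randn(1, 16, 16), torch.randn(1, 16, 16), torch.randn(1, 16, 16)),   # level 1
]

print(map_result_sketch(coeffs, lambda t: tuple(t.shape)))
# [(1, 8, 8), ((1, 8, 8), (1, 8, 8), (1, 8, 8)), ((1, 16, 16), (1, 16, 16), (1, 16, 16))]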


def _check_if_tensor(to_check: Any) -> torch.Tensor:
@@ -245,6 +283,7 @@ def _check_if_tensor(to_check: Any) -> torch.Tensor:
def waverec2(
coeffs: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
wavelet: Union[Wavelet, str],
axes: Tuple[int, int] = (-2, -1),
) -> torch.Tensor:
"""Reconstruct a signal from wavelet coefficients.
@@ -258,6 +297,8 @@ def waverec2(
and D diagonal coefficients.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
axes (Tuple[int, int]): Compute the transform over these axes instead of the
last two. Defaults to (-2, -1).
Returns:
torch.Tensor: The reconstructed signal of shape ``[batch, height, width]`` or
@@ -279,16 +320,23 @@
>>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
"""
if tuple(axes) != (-2, -1):
if len(axes) != 2:
raise ValueError("2D transforms work with two axes.")
else:
_check_axes_argument(list(axes))
swap_fn = partial(_swap_axes, axes=list(axes))
coeffs = _map_result(coeffs, swap_fn)

ds = None
wavelet = _as_wavelet(wavelet)

res_ll = _check_if_tensor(coeffs[0])
torch_device = res_ll.device
torch_dtype = res_ll.dtype

fold = False
if res_ll.dim() == 4:
if res_ll.dim() >= 4:
# avoid the channel sum, fold the channels into batches.
fold = True
coeffs, ds = _waverec2d_fold_channels_2d_list(coeffs)
res_ll = _check_if_tensor(coeffs[0])

@@ -348,7 +396,10 @@ def waverec2(
if padr > 0:
res_ll = res_ll[..., :-padr]

if fold:
res_ll = _unfold_channels(res_ll, list(ds))
if ds:
res_ll = _unfold_axes(res_ll, list(ds), 2)

if axes != (-2, -1):
res_ll = _undo_swap_axes(res_ll, list(axes))

return res_ll
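
Not part of the diff: an end-to-end usage sketch of the new axes argument, mirroring the tests below. The tensor shape is chosen for illustration; with "haar" and even-sized axes the reconstruction needs no trimming.

import torch
import ptwt

data = torch.randn(16, 64, 64, 3, dtype=torch.float64)   # channels-last layout

# transform over the spatial axes (1, 2) instead of the default (-2, -1)
coeffs = ptwt.wavedec2(data, "haar", level=2, mode="reflect", axes=(1, 2))
rec = ptwt.waverec2(coeffs, "haar", axes=(1, 2))

print(rec.shape)  # torch.Size([16, 64, 64, 3])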
71 changes: 71 additions & 0 deletions tests/test_convolution_fwt.py
@@ -251,3 +251,74 @@ def test_input_1d_dimension_error():
with pytest.raises(ValueError):
data = torch.randn(50)
wavedec2(data, "haar", 4)


def _compare_coeffs(ptwt_res, pywt_res):
test_list = []
for ptwtcs, pywtcs in zip(ptwt_res, pywt_res):
if isinstance(ptwtcs, tuple):
test_list.extend(
tuple(
np.allclose(ptwtc.numpy(), pywtc)
for ptwtc, pywtc in zip(ptwtcs, pywtcs)
)
)
else:
test_list.append(np.allclose(ptwtcs.numpy(), pywtcs))
return test_list


@pytest.mark.parametrize(
"size", [(50, 20, 128, 128), (8, 49, 21, 128, 128), (6, 4, 4, 5, 64, 64)]
)
def test_multidim_input(size):
"""Test the error for multi-dimensional inputs to wavedec2."""
data = torch.randn(*size, dtype=torch.float64)
wavelet = "db2"
level = 3

pt_res = wavedec2(data, wavelet=wavelet, level=level, mode="reflect")
pywt_res = pywt.wavedec2(data.numpy(), wavelet=wavelet, level=level, mode="reflect")
rec = waverec2(pt_res, wavelet)

# test coefficients
test_list = _compare_coeffs(pt_res, pywt_res)
assert all(test_list)

# test reconstruction.
assert np.allclose(
data.numpy(), rec.numpy()[..., : data.shape[-2], : data.shape[-1]]
)


@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (-3, -2), (0, 1), (1, 0)])
def test_axis_argument(axes):
"""Ensure the axes argument works as expected."""
data = torch.randn([32, 32, 32, 32], dtype=torch.float64)

ptwt_coeff = wavedec2(data, "db2", level=3, mode="reflect", axes=axes)
pywt_coeff = pywt.wavedec2(data, "db2", level=3, mode="reflect", axes=axes)
rec = waverec2(ptwt_coeff, "db2", axes=axes)

# test coefficients
test_list = _compare_coeffs(ptwt_coeff, pywt_coeff)
assert all(test_list)

# test reconstruction.
assert np.allclose(
data.numpy(), rec.numpy()[..., : data.shape[-2], : data.shape[-1]]
)


def test_axis_error_axes_count():
"""Check the error for too many axes."""
with pytest.raises(ValueError):
data = torch.randn([32, 32, 32, 32], dtype=torch.float64)
wavedec2(data, "haar", 1, axes=(1, 2, 3))


def test_axis_error_axes_repetition():
"""Check the error for axes repetition."""
with pytest.raises(ValueError):
data = torch.randn([32, 32, 32, 32], dtype=torch.float64)
wavedec2(data, "haar", 1, axes=(2, 2))
