Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Axis support #67

Merged
merged 23 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Tests

on: [ push, pull_request ]
on: [ push ]

jobs:
tests:
Expand All @@ -20,7 +20,7 @@ jobs:
run: pip install nox
- name: Test with pytest
run:
nox -s test
nox -s fast-test
lint:
name: lint
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
:alt: GitHub Actions

.. image:: https://readthedocs.org/projects/pytorch-wavelet-toolbox/badge/?version=latest
:target: https://pytorch-wavelet-toolbox.readthedocs.io/en/latest/?badge=latest
:target: https://pytorch-wavelet-toolbox.readthedocs.io/en/latest/ptwt.html
:alt: Documentation Status

.. image:: https://img.shields.io/pypi/pyversions/ptwt
Expand Down
90 changes: 76 additions & 14 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utility methods to compute wavelet decompositions from a dataset."""
from typing import List, Optional, Protocol, Sequence, Tuple, Union
from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union

import numpy as np
import pywt
import torch

Expand Down Expand Up @@ -103,19 +104,80 @@ def _pad_symmetric(
return signal


def _fold_channels(data: torch.Tensor) -> torch.Tensor:
"""Fold [batch, channel, height width] into [batch*channel, height, widht]."""
ds = data.shape
return torch.reshape(
data,
[
ds[0] * ds[1],
ds[2],
ds[3],
],
def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int]]:
"""Fold unchanged leading dimensions into a single batch dimension.

Args:
data ( torch.Tensor): The input data array.
keep_no (int): The number of dimensions to keep.

Returns:
Tuple[ torch.Tensor, List[int]]:
The folded result array, and the shape of the original input.
"""
dshape = list(data.shape)
return (
torch.reshape(data, [int(np.prod(dshape[:-keep_no]))] + dshape[-keep_no:]),
dshape,
)


def _unfold_channels(data: torch.Tensor, ds: List[int]) -> torch.Tensor:
"""Unfold [batch*channel, height, widht] into [batch, channel, height, width]."""
return torch.reshape(data, [ds[0], ds[1], data.shape[1], data.shape[2]])
def _unfold_axes(data: torch.Tensor, ds: List[int], keep_no: int) -> torch.Tensor:
"""Unfold i.e. [batch*channel,height,widht] to [batch,channel,height,width]."""
return torch.reshape(data, ds[:-keep_no] + list(data.shape[-keep_no:]))


def _check_if_tensor(array: Any) -> torch.Tensor:
if not isinstance(array, torch.Tensor):
raise ValueError(
"First element of coeffs must be the approximation coefficient tensor."
)
return array


def _check_axes_argument(axes: List[int]) -> None:
if len(set(axes)) != len(axes):
raise ValueError("Cant transform the same axis twice.")


def _get_transpose_order(
axes: List[int], data_shape: List[int]
) -> Tuple[List[int], List[int]]:
axes = list(map(lambda a: a + len(data_shape) if a < 0 else a, axes))
all_axes = list(range(len(data_shape)))
remove_transformed = list(filter(lambda a: a not in axes, all_axes))
return remove_transformed, axes


def _swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor:
_check_axes_argument(axes)
front, back = _get_transpose_order(axes, list(data.shape))
return torch.permute(data, front + back)


def _undo_swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor:
_check_axes_argument(axes)
front, back = _get_transpose_order(axes, list(data.shape))
restore_sorted = torch.argsort(torch.tensor(front + back)).tolist()
return torch.permute(data, restore_sorted)


def _map_result(
data: List[Union[torch.Tensor, Any]], # following jax tree_map typing can be Any
function: Callable[[Any], torch.Tensor],
) -> List[Union[torch.Tensor, Any]]:
# Apply the given function to the input list of tensor and tuples.
result_lst: List[Union[torch.Tensor, Any]] = []
for element in data:
if isinstance(element, torch.Tensor):
result_lst.append(function(element))
elif isinstance(element, tuple):
result_lst.append(
(function(element[0]), function(element[1]), function(element[2]))
)
elif isinstance(element, dict):
new_dict = {}
for key, value in element.items():
new_dict[key] = function(value)
result_lst.append(new_dict)
return result_lst
4 changes: 2 additions & 2 deletions src/ptwt/continuous_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ def wavefun(
"""Define a grid and evaluate the wavelet on it."""
length = 2**precision
# load the bounds from untyped pywt code.
lower_bound: float = float(self.lower_bound) # type: ignore
upper_bound: float = float(self.upper_bound) # type: ignore
lower_bound: float = float(self.lower_bound)
upper_bound: float = float(self.upper_bound)
grid = torch.linspace(
lower_bound,
upper_bound,
Expand Down
139 changes: 89 additions & 50 deletions src/ptwt/conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from ._util import (
Wavelet,
_as_wavelet,
_fold_channels,
_fold_axes,
_get_len,
_is_dtype_supported,
_pad_symmetric,
_unfold_channels,
_unfold_axes,
)


Expand Down Expand Up @@ -217,47 +217,66 @@ def _adjust_padding_at_reconstruction(
return pad_end, pad_start


def _wavedec_fold_channels_1d(data: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
data = data.unsqueeze(-2)
ds = data.shape
data = _fold_channels(data)
return data, list(ds)
def _preprocess_tensor_dec1d(
data: torch.Tensor,
) -> Tuple[torch.Tensor, Union[List[int], None]]:
"""Preprocess input tensor dimensions.

Args:
data (torch.Tensor): An input tensor of any shape.

def _wavedec_unfold_channels_1d_list(
result_list: List[torch.Tensor], ds: List[int]
Returns:
Tuple[torch.Tensor, Union[List[int], None]]:
A data tensor of shape [new_batch, 1, to_process]
and the original shape, if the shape has changed.
"""
ds = None
if len(data.shape) == 1:
# assume time series
data = data.unsqueeze(0).unsqueeze(0)
elif len(data.shape) == 2:
# assume batched time series
data = data.unsqueeze(1)
else:
data, ds = _fold_axes(data, 1)
data = data.unsqueeze(1)
return data, ds


def _postprocess_result_list_dec1d(
result_lst: List[torch.Tensor], ds: List[int]
) -> List[torch.Tensor]:
unfold_res = []
for res_coeff in result_list:
unfold_res.append(
_unfold_channels(res_coeff.unsqueeze(1), list(ds)).squeeze(-2)
)
return unfold_res
# Unfold axes for the wavelets
unfold_list = []
for fres in result_lst:
unfold_list.append(_unfold_axes(fres, ds, 1))
return unfold_list


def _waverec_fold_channels_1d_list(
coeff_list: List[torch.Tensor],
def _preprocess_result_list_rec1d(
result_lst: List[torch.Tensor],
) -> Tuple[List[torch.Tensor], List[int]]:
folded = []
ds = coeff_list[0].unsqueeze(-2).shape
for to_fold_coeff in coeff_list:
folded.append(_fold_channels(to_fold_coeff.unsqueeze(-2)).squeeze(-2))
return folded, list(ds)
# Fold axes for the wavelets
fold_coeffs = []
ds = list(result_lst[0].shape)
for uf_coeff in result_lst:
f_coeff, _ = _fold_axes(uf_coeff, 1)
fold_coeffs.append(f_coeff)
return fold_coeffs, ds


def wavedec(
data: torch.Tensor,
wavelet: Union[Wavelet, str],
mode: str = "reflect",
level: Optional[int] = None,
axis: int = -1,
) -> List[torch.Tensor]:
"""Compute the analysis (forward) 1d fast wavelet transform.

Args:
data (torch.Tensor): The input time series,
1d inputs are interpreted as ``[time]``,
2d inputs as ``[batch_size, time]``,
and 3d inputs as ``[batch_size, channels, time]``.
By default the last axis is transformed.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
Please consider the output from ``pywt.wavelist(kind='discrete')``
Expand All @@ -274,9 +293,11 @@ def wavedec(
Zero padding pads zeros.
Constant padding replicates border values.
Periodic padding cyclically repeats samples.

level (int): The scale level to be computed.
Defaults to None.
axis (int): Compute the transform over this axis instead of the
last one. Defaults to -1.


Returns:
list: A list::
Expand All @@ -287,7 +308,8 @@ def wavedec(
approximation and D detail coefficients.

Raises:
ValueError: If the dtype of the input data tensor is unsupported.
ValueError: If the dtype of the input data tensor is unsupported or
if more than one axis is provided.

Example:
>>> import torch
Expand All @@ -300,21 +322,17 @@ def wavedec(
>>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
>>> mode='zero', level=2)
"""
fold = False
if data.dim() == 1:
# assume time series
data = data.unsqueeze(0).unsqueeze(0)
elif data.dim() == 2:
# assume batched time series
data = data.unsqueeze(1)
elif data.dim() == 3:
# assume batch, channels, time -> fold channels
fold = True
data, ds = _wavedec_fold_channels_1d(data)
if axis != -1:
if isinstance(axis, int):
data = data.swapaxes(axis, -1)
else:
raise ValueError("wavedec transforms a single axis only.")

if not _is_dtype_supported(data.dtype):
raise ValueError(f"Input dtype {data.dtype} not supported")

data, ds = _preprocess_tensor_dec1d(data)

dec_lo, dec_hi, _, _ = _get_filter_tensors(
wavelet, flip=True, device=data.device, dtype=data.dtype
)
Expand All @@ -332,28 +350,38 @@ def wavedec(
res_lo, res_hi = torch.split(res, 1, 1)
result_list.append(res_hi.squeeze(1))
result_list.append(res_lo.squeeze(1))
result_list.reverse()

if ds:
result_list = _postprocess_result_list_dec1d(result_list, ds)

# unfold if necessary
if fold:
result_list = _wavedec_unfold_channels_1d_list(result_list, ds)
if axis != -1:
swap = []
for coeff in result_list:
swap.append(coeff.swapaxes(axis, -1))
result_list = swap

return result_list[::-1]
return result_list


def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.Tensor:
def waverec(
coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str], axis: int = -1
) -> torch.Tensor:
"""Reconstruct a signal from wavelet coefficients.

Args:
coeffs (list): The wavelet coefficient list produced by wavedec.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
axis (int): Transform this axis instead of the last one. Defaults to -1.

Returns:
torch.Tensor: The reconstructed signal.

Raises:
ValueError: If the dtype of the coeffs tensor is unsupported or if the
coefficients have incompatible shapes, dtypes or devices.
coefficients have incompatible shapes, dtypes or devices or if
more than one axis is provided.

Example:
>>> import torch
Expand All @@ -379,11 +407,19 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
elif torch_dtype != coeff.dtype:
raise ValueError("coefficients must have the same dtype")

if axis != -1:
swap = []
if isinstance(axis, int):
for coeff in coeffs:
swap.append(coeff.swapaxes(axis, -1))
coeffs = swap
else:
raise ValueError("waverec transforms a single axis only.")

# fold channels, if necessary.
fold = False
if coeffs[0].dim() == 3:
fold = True
coeffs, ds = _waverec_fold_channels_1d_list(coeffs)
ds = None
if coeffs[0].dim() >= 3:
coeffs, ds = _preprocess_result_list_rec1d(coeffs)

_, _, rec_lo, rec_hi = _get_filter_tensors(
wavelet, flip=False, device=torch_device, dtype=torch_dtype
Expand All @@ -408,7 +444,10 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
if padr > 0:
res_lo = res_lo[..., :-padr]

if fold:
res_lo = _unfold_channels(res_lo.unsqueeze(-2), list(ds)).squeeze(-2)
if ds:
res_lo = _unfold_axes(res_lo, ds, 1)

if axis != -1:
res_lo = res_lo.swapaxes(axis, -1)

return res_lo
Loading
Loading