Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
felixblanke committed Jun 3, 2024
1 parent c2dccd0 commit 57c1ab7
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 36 deletions.
30 changes: 15 additions & 15 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Utility methods to compute wavelet decompositions from a dataset."""

from collections.abc import Sequence
import typing
from collections.abc import Sequence
from typing import Any, Callable, Optional, Protocol, Union, cast, overload
from typing_extensions import Unpack

import numpy as np
import pywt
import torch
from typing_extensions import Unpack

from .constants import OrthogonalizeMethod

Expand All @@ -30,10 +30,13 @@ def __len__(self) -> int:
"""Return the number of filter coefficients."""
return len(self.dec_lo)


WaveletDetailTuple2d = tuple[torch.Tensor, torch.Tensor, torch.Tensor]
WaveletDetailDict = dict[str, torch.Tensor]

WaveletCoeffDetailTuple2d = tuple[torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]]]
WaveletCoeffDetailTuple2d = tuple[
torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]]
]
WaveletCoeffDetailDict = tuple[torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]]]


Expand Down Expand Up @@ -96,7 +99,7 @@ def _pad_symmetric_1d(signal: torch.Tensor, pad_list: tuple[int, int]) -> torch.
cat_list.insert(0, signal[:padl].flip(0))
if padr > 0:
cat_list.append(signal[-padr::].flip(0))
return torch.cat(cat_list, axis=0) # type: ignore
return torch.cat(cat_list, dim=0)


def _pad_symmetric(
Expand All @@ -118,7 +121,7 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int
"""Fold unchanged leading dimensions into a single batch dimension.
Args:
data ( torch.Tensor): The input data array.
data (torch.Tensor): The input data array.
keep_no (int): The number of dimensions to keep.
Returns:
Expand Down Expand Up @@ -176,21 +179,19 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor:
def _map_result(
data: WaveletCoeffDetailTuple2d,
function: Callable[[torch.Tensor], torch.Tensor],
) -> WaveletCoeffDetailTuple2d:
...
) -> WaveletCoeffDetailTuple2d: ...


@overload
def _map_result(
data: WaveletCoeffDetailDict,
function: Callable[[torch.Tensor], torch.Tensor],
) -> WaveletCoeffDetailDict:
...
) -> WaveletCoeffDetailDict: ...


def _map_result(
data: Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict],
function: Callable[[torch.Tensor], torch.Tensor]
function: Callable[[torch.Tensor], torch.Tensor],
) -> Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict]:
return_tuple = isinstance(data, tuple)
approx = function(data[0])
Expand All @@ -210,14 +211,13 @@ def _map_result(
)
)
elif isinstance(element, dict):
new_dict = {
key: function(value)
for key, value in element.items()
}
new_dict = {key: function(value) for key, value in element.items()}
result_lst.append(new_dict)
else:
raise AssertionError(f"Unexpected input type {type(element)}")

return_val = approx, *result_lst
return_val = cast(Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], return_val)
return_val = cast(
Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], return_val
)
return return_val
1 change: 1 addition & 0 deletions src/ptwt/conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This module treats boundaries with edge-padding.
"""

from collections.abc import Sequence
from typing import Optional, Union, cast

Expand Down
4 changes: 3 additions & 1 deletion src/ptwt/conv_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def _fwt_pad2(
return data_pad


def _waverec2d_fold_channels_2d_list(coeffs: WaveletCoeffDetailTuple2d) -> tuple[WaveletCoeffDetailTuple2d, list[int]]:
def _waverec2d_fold_channels_2d_list(
coeffs: WaveletCoeffDetailTuple2d,
) -> tuple[WaveletCoeffDetailTuple2d, list[int]]:
# fold the input coefficients for processing conv2d_transpose.
ds = list(_check_if_tensor(coeffs[0]).shape)
return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds
Expand Down
5 changes: 1 addition & 4 deletions src/ptwt/conv_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,7 @@ def _waverec3d_fold_channels_3d_list(
fold_coeffs: list[dict[str, torch.Tensor]] = []
ds = list(_check_if_tensor(coeffs[0]).shape)
fold_coeffs = [
{
key: _fold_axes(value, 3)[0]
for key, value in coeff.items()
}
{key: _fold_axes(value, 3)[0] for key, value in coeff.items()}
for coeff in coeffs[1:]
]
return (fold_approx_coeff, *fold_coeffs), ds
Expand Down
4 changes: 1 addition & 3 deletions src/ptwt/matmul_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,7 @@ def _construct_analysis_matrices(
current_width = current_width // 2
self.size_list.append((current_height, current_width))

def __call__(
self, input_signal: torch.Tensor
) -> WaveletCoeffDetailTuple2d:
def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d:
"""Compute the fwt for the given input signal.
The fwt matrix is set up during the first call
Expand Down
8 changes: 2 additions & 6 deletions src/ptwt/matmul_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ def _construct_analysis_matrices(
)
self.size_list.append((current_depth, current_height, current_width))

def __call__(
self, input_signal: torch.Tensor
) -> WaveletCoeffDetailDict:
def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailDict:
"""Compute a separable 3d-boundary wavelet transform.
Args:
Expand Down Expand Up @@ -388,9 +386,7 @@ def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Ten
return cat_tensor
return self._cat_coeff_recursive(done_dict)

def __call__(
self, coefficients: WaveletCoeffDetailDict
) -> torch.Tensor:
def __call__(self, coefficients: WaveletCoeffDetailDict) -> torch.Tensor:
"""Reconstruct a batched 3d-signal from its coefficients.
Args:
Expand Down
18 changes: 12 additions & 6 deletions src/ptwt/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
import pywt
import torch

from ._util import Wavelet, WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict, _as_wavelet
from ._util import (
Wavelet,
WaveletCoeffDetailDict,
WaveletCoeffDetailTuple2d,
_as_wavelet,
)
from .constants import ExtendedBoundaryMode, OrthogonalizeMethod
from .conv_transform import wavedec, waverec
from .conv_transform_2 import wavedec2, waverec2
Expand Down Expand Up @@ -384,7 +389,8 @@ def get_natural_order(self, level: int) -> list[str]:
return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)]

def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[
[torch.Tensor], WaveletCoeffDetailTuple2d,
[torch.Tensor],
WaveletCoeffDetailTuple2d,
]:
if self.mode == "boundary":
shape = tuple(shape)
Expand Down Expand Up @@ -413,7 +419,9 @@ def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[
wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes
)

def _get_waverec(self, shape: tuple[int, ...]) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]:
def _get_waverec(
self, shape: tuple[int, ...]
) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]:
if self.mode == "boundary":
shape = tuple(shape)
if shape not in self.matrix_waverec2_dict.keys():
Expand Down Expand Up @@ -450,9 +458,7 @@ def _transform_tuple_to_fsdict_func(
self,
fsdict_func: Callable[[WaveletCoeffDetailDict], torch.Tensor],
) -> Callable[[WaveletCoeffDetailTuple2d], torch.Tensor]:
def _fsdict_func(
coeffs: WaveletCoeffDetailTuple2d
) -> torch.Tensor:
def _fsdict_func(coeffs: WaveletCoeffDetailTuple2d) -> torch.Tensor:
# assert for type checking
assert len(coeffs) == 2
a, (h, v, d) = coeffs
Expand Down
2 changes: 1 addition & 1 deletion src/ptwt/separable_conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from ._util import (
Wavelet,
WaveletCoeffDetailTuple2d,
WaveletCoeffDetailDict,
WaveletCoeffDetailTuple2d,
_as_wavelet,
_check_axes_argument,
_check_if_tensor,
Expand Down

0 comments on commit 57c1ab7

Please sign in to comment.