Skip to content

Commit

Permalink
Address linter remarks
Browse files Browse the repository at this point in the history
  • Loading branch information
felixblanke committed Jun 3, 2024
1 parent 57c1ab7 commit 9950294
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 26 deletions.
3 changes: 2 additions & 1 deletion src/ptwt/_stationary_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def _iswt(
"""Inverts a 1d stationary wavelet transform.
Args:
coeffs (Sequence[torch.Tensor]): The coefficients as computed by the swt function.
coeffs (Sequence[torch.Tensor]): The coefficients as computed
by the swt function.
wavelet (Union[pywt.Wavelet, str]): The wavelet used for the forward transform.
axis (int, optional): The axis the forward trasform was computed over.
Defaults to -1.
Expand Down
1 change: 0 additions & 1 deletion src/ptwt/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def _map_result(
data: Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict],
function: Callable[[torch.Tensor], torch.Tensor],
) -> Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict]:
return_tuple = isinstance(data, tuple)
approx = function(data[0])
result_lst: list[
Union[
Expand Down
8 changes: 4 additions & 4 deletions src/ptwt/conv_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
torch.nn.functional.conv_transpose2d under the hood.
"""

from collections.abc import Sequence
from functools import partial
from typing import Optional, Union, cast

Expand Down Expand Up @@ -169,7 +168,7 @@ def wavedec2(
last two. Defaults to (-2, -1).
Returns:
WaveletTransformReturn2d: A tuple containing the wavelet coefficients.
WaveletCoeffDetailTuple2d: A tuple containing the wavelet coefficients.
The coefficients are in pywt order. That is::
[cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] .
Expand Down Expand Up @@ -248,8 +247,9 @@ def waverec2(
or forward transform by running transposed convolutions.
Args:
coeffs (WaveletTransformReturn2d): The wavelet coefficient tupl produced by wavedec2.
The coefficients must be in pywt order. That is::
coeffs (WaveletCoeffDetailTuple2d): The wavelet coefficient tuple
produced by wavedec2. The coefficients must be in pywt order.
That is::
[cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] .
Expand Down
5 changes: 3 additions & 2 deletions src/ptwt/conv_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def wavedec3(
instead of the last three. Defaults to (-3, -2, -1).
Returns:
WaveletTransformReturn3d: A tuple with the lll coefficients and
WaveletCoeffDetailDict: A tuple with the lll coefficients and
dictionaries with the filter order strings::
("aad", "ada", "add", "daa", "dad", "dda", "ddd")
Expand Down Expand Up @@ -236,7 +236,8 @@ def waverec3(
"""Reconstruct a signal from wavelet coefficients.
Args:
coeffs (WaveletTransformReturn3d): The wavelet coefficient tuple produced by wavedec3.
coeffs (WaveletCoeffDetailDict): The wavelet coefficient tuple
produced by wavedec3.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
axes (tuple[int, int, int]): Transform these axes instead of the
Expand Down
6 changes: 3 additions & 3 deletions src/ptwt/matmul_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import sys
from collections import Sequence
from collections.abc import Sequence
from typing import Optional, Union

import numpy as np
Expand Down Expand Up @@ -600,8 +600,8 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor:
"""Run the synthesis or inverse matrix fwt.
Args:
coefficients (Sequence[torch.Tensor]): The coefficients produced by the forward
transform.
coefficients (Sequence[torch.Tensor]): The coefficients produced
by the forward transform.
Returns:
torch.Tensor: The input signal reconstruction.
Expand Down
7 changes: 3 additions & 4 deletions src/ptwt/matmul_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

import sys
from collections.abc import Sequence
from functools import partial
from typing import Optional, Union, cast

Expand Down Expand Up @@ -430,8 +429,8 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailTuple2d:
This transform affects the last two dimensions.
Returns:
(WaveletTransformReturn2d): The resulting coefficients per level are stored in
a pywt style tuple. The tuple is ordered as::
(WaveletCoeffDetailTuple2d): The resulting coefficients per level
are stored in a pywt style tuple. The tuple is ordered as::
(ll, (lh, hl, hh), ...)
Expand Down Expand Up @@ -729,7 +728,7 @@ def __call__(
"""Compute the inverse matrix 2d fast wavelet transform.
Args:
coefficients (WaveletTransformReturn2d): The coefficient tuple as returned
coefficients (WaveletCoeffDetailTuple2d): The coefficient tuple as returned
by the `MatrixWavedec2`-Object.
Returns:
Expand Down
4 changes: 2 additions & 2 deletions src/ptwt/matmul_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffDetailDict:
ValueError: If the input dimensions don't work.
Returns:
WaveletTransformReturn3d:
WaveletCoeffDetailDict:
A tuple with the approximation coefficients,
and a coefficient dict for each scale.
"""
Expand Down Expand Up @@ -390,7 +390,7 @@ def __call__(self, coefficients: WaveletCoeffDetailDict) -> torch.Tensor:
"""Reconstruct a batched 3d-signal from its coefficients.
Args:
coefficients (WaveletTransformReturn3d):
coefficients (WaveletCoeffDetailDict):
The output from MatrixWavedec3, consisting of a tuple
of the approximation coefficients and a dict with the
detail coefficients for each scale.
Expand Down
2 changes: 1 addition & 1 deletion src/ptwt/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Sequence
from functools import partial
from itertools import product
from typing import TYPE_CHECKING, Callable, Optional, Union, cast
from typing import TYPE_CHECKING, Callable, Optional, Union

import numpy as np
import pywt
Expand Down
14 changes: 6 additions & 8 deletions src/ptwt/separable_conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@
from typing import Optional, Union

import numpy as np
import pywt
import torch

from ._util import (
Wavelet,
WaveletCoeffDetailDict,
WaveletCoeffDetailTuple2d,
_as_wavelet,
_check_axes_argument,
_check_if_tensor,
Expand Down Expand Up @@ -127,7 +125,7 @@ def _separable_conv_wavedecn(
level (int): The desired decomposition level.
Returns:
WaveletTransformReturn3d: A tuple with the approximation coefficients,
WaveletCoeffDetailDict: A tuple with the approximation coefficients,
and a coefficient dict for each scale.
"""
result: list[dict[str, torch.Tensor]] = []
Expand Down Expand Up @@ -156,7 +154,7 @@ def _separable_conv_waverecn(
"""Separable n-dimensional wavelet synthesis transform.
Args:
coeffs (WaveletTransformReturn3d):
coeffs (WaveletCoeffDetailDict):
The output as produced by `_separable_conv_wavedecn`.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet, as used by `_separable_conv_wavedecn`.
Expand Down Expand Up @@ -208,7 +206,7 @@ def fswavedec2(
ValueError: If the data is not a batched 2D signal.
Returns:
WaveletTransformReturn3d:
WaveletCoeffDetailDict:
A tuple with the lll coefficients and dictionaries
with the filter order strings::
Expand Down Expand Up @@ -277,7 +275,7 @@ def fswavedec3(
ValueError: If the input is not a batched 3D signal.
Returns:
WaveletTransformReturn3d:
WaveletCoeffDetailDict:
A tuple with the lll coefficients and dictionaries
with the filter order strings::
Expand Down Expand Up @@ -332,7 +330,7 @@ def fswaverec2(
the hood.
Args:
coeffs (WaveletTransformReturn3d):
coeffs (WaveletCoeffDetailDict):
The wavelet coefficients as computed by `fswavedec2`.
wavelet (Wavelet or str): The wavelet to use for the
synthesis transform.
Expand Down Expand Up @@ -396,7 +394,7 @@ def fswaverec3(
"""Compute a fully separable 3D-padded synthesis wavelet transform.
Args:
coeffs (WaveletTransformReturn3d):
coeffs (WaveletCoeffDetailDict):
The wavelet coefficients as computed by `fswavedec3`.
wavelet (Wavelet or str): The wavelet to use for the
synthesis transform.
Expand Down

0 comments on commit 9950294

Please sign in to comment.