Skip to content

Commit

Permalink
fix lint.
Browse files Browse the repository at this point in the history
  • Loading branch information
v0lta committed Sep 22, 2023
1 parent 3958660 commit 3dd733a
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 27 deletions.
4 changes: 0 additions & 4 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
"""Utility methods to compute wavelet decompositions from a dataset."""
<<<<<<< HEAD
from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union
=======
from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union
>>>>>>> ac76b4478066c7891b94e00ed661712991214d5c

import numpy as np
import pywt
Expand Down
2 changes: 1 addition & 1 deletion src/ptwt/conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def _waverec_fold_channels_1d_list(
def _preprocess_tensor_dec1d(
data: torch.Tensor,
) -> Tuple[torch.Tensor, Union[List[int], None]]:
"""Preprocess input tensor dimensions
"""Preprocess input tensor dimensions.
Args:
data (torch.Tensor): An input tensor of any shape.
Expand Down
16 changes: 0 additions & 16 deletions src/ptwt/conv_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,34 +259,18 @@ def waverec3(
if len(axes) != 3:
raise ValueError("3D transforms work with two axes")
else:
<<<<<<< HEAD
_check_axes_argument(axes)
=======
_check_axes_argument(list(axes))
>>>>>>> ac76b4478066c7891b94e00ed661712991214d5c
swap_axes_fn = partial(_swap_axes, axes=list(axes))
coeffs = _map_result(coeffs, swap_axes_fn)

wavelet = _as_wavelet(wavelet)
ds = None
# the Union[tensor, dict] idea is coming from pywt. We don't change it here.
<<<<<<< HEAD
res_lll = coeffs[0]
if not isinstance(res_lll, torch.Tensor):
raise ValueError(
"First element of coeffs must be the approximation coefficient tensor."
)

if len(res_lll.shape) >= 5:
coeffs, ds = _waverec3d_fold_channels_3d_list(coeffs)
res_lll = coeffs[0] # TODO: Check if this is tensor.
=======
res_lll = _check_if_tensor(coeffs[0])

if len(res_lll.shape) >= 5:
coeffs, ds = _waverec3d_fold_channels_3d_list(coeffs)
res_lll = _check_if_tensor(coeffs[0])
>>>>>>> ac76b4478066c7891b94e00ed661712991214d5c

torch_device = res_lll.device
torch_dtype = res_lll.dtype
Expand Down
12 changes: 6 additions & 6 deletions src/ptwt/matmul_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(
wavelet: Union[Wavelet, str],
level: Optional[int] = None,
boundary: str = "qr",
axis: int = -1,
axis: Optional[int] = -1,
) -> None:
"""Create a matrix-fwt object.
Expand Down Expand Up @@ -313,11 +313,11 @@ def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]:
Args:
input_signal (torch.Tensor): Batched input data.
An example shape could be ``[batch_size, time]``.
Inputs can have any dimension.
This transform affects the last axis by default.
Use the axis argument in the constructor to choose
another axis.
An example shape could be ``[batch_size, time]``.
Inputs can have any dimension.
This transform affects the last axis by default.
Use the axis argument in the constructor to choose
another axis.
Returns:
List[torch.Tensor]: A list with the coefficients for each scale.
Expand Down
1 change: 1 addition & 0 deletions tests/test_matrix_fwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def test_matrix1d_batch_channel(size):

@pytest.mark.parametrize("axis", (0, 1, 2, 3, 4))
def test_axis_1d(axis):
"""Ensure the axis argument is supported correctly."""
data = torch.randn(24, 24, 24, 24, 24).type(torch.float64)
matrix_wavedec = MatrixWavedec(wavelet="haar", level=3, axis=axis)
coeff = matrix_wavedec(data)
Expand Down

0 comments on commit 3dd733a

Please sign in to comment.