
Improve typing and docstrings #87

Merged: 42 commits, Jun 18, 2024 (changes shown from the first 17 commits)

Commits:
8ff2c89: Exchange List with Sequence in args (felixblanke, Jun 3, 2024)
aacca4b: Use builtin list instead of List (felixblanke, Jun 3, 2024)
fa14797: Use builtin tuple instead of Tuple (felixblanke, Jun 3, 2024)
75c0bc0: Use builtin dict instead of Dict (felixblanke, Jun 3, 2024)
69c69a2: Change return types to tuple (felixblanke, Jun 3, 2024)
3126c89: Add function overloads (felixblanke, Jun 3, 2024)
46198d0: refactor _map_result (felixblanke, Jun 3, 2024)
168b4a0: More adaptions to the new types (felixblanke, Jun 3, 2024)
09811b0: Fix _map_result (felixblanke, Jun 3, 2024)
cb1b3db: Improve wavelet args for separable transforms (felixblanke, Jun 3, 2024)
310b1af: Rename wavelet coefficient types. (felixblanke, Jun 3, 2024)
c2dccd0: tighten return type (felixblanke, Jun 3, 2024)
57c1ab7: Format (felixblanke, Jun 3, 2024)
9950294: Address linter remarks (felixblanke, Jun 3, 2024)
c9af10c: Adopt flake8 rules recommended by black project (felixblanke, Jun 3, 2024)
df1884e: Remove matplotlib requirement (felixblanke, Jun 3, 2024)
414492d: Minor typing improvement (felixblanke, Jun 4, 2024)
da7ac93: Merge branch 'main' into improve-typing (cthoyt, Jun 10, 2024)
f495acc: Change error type to ValueError at input validation (felixblanke, Jun 11, 2024)
acd8769: Improve type hints (felixblanke, Jun 11, 2024)
6d60132: Fix axis typehint and add to docstr (felixblanke, Jun 11, 2024)
e035881: Move type aliases into public 'constants' module and add docstr. (felixblanke, Jun 13, 2024)
5051f13: Fixate some type aliases to not be resolved in docs. (felixblanke, Jun 13, 2024)
1c6b4ea: Improve docstrings (felixblanke, Jun 13, 2024)
23ed9d8: Replace redundant docstr info with refers to type alias (felixblanke, Jun 14, 2024)
1139876: Add some special member funcs to docs (felixblanke, Jun 14, 2024)
dd10c82: Fix favicon and move version module in API (felixblanke, Jun 14, 2024)
0f3d01a: Fix import (felixblanke, Jun 14, 2024)
cd4c131: Change detail coeff tuple to namedtuple (felixblanke, Jun 14, 2024)
b1c34d5: Change type str back to imperative mood (felixblanke, Jun 14, 2024)
3dcecbe: Merge branch 'main' into improve-typing (felixblanke, Jun 14, 2024)
d4ed327: Add comment clarification (felixblanke, Jun 14, 2024)
1991a07: Rename type aliases. (felixblanke, Jun 14, 2024)
3939070: Introduce WaveletCoeff2dSeparable alias (felixblanke, Jun 14, 2024)
b324728: Make cast in _map_result more narrow and fix tuple creation (felixblanke, Jun 14, 2024)
a55f223: Update typing in tests (felixblanke, Jun 14, 2024)
48e87a7: Also import 2d separable type in __init__ (felixblanke, Jun 14, 2024)
4282c9d: Improve type alias docstr (felixblanke, Jun 14, 2024)
942ad60: Improve typing in JIT code (felixblanke, Jun 14, 2024)
b8a510a: Adapt right pad logic to avoid JIT tracer warning (felixblanke, Jun 14, 2024)
0ac20c5: Refactor of order methods for WaveletPacket2d (felixblanke, Jun 14, 2024)
c973350: Use mypy main branch for typing session for now (felixblanke, Jun 18, 2024)
4 changes: 3 additions & 1 deletion .flake8
@@ -23,6 +23,8 @@ ignore =
# asserts are ok in test.
S101
C901
extend-select = B950
extend-ignore = E501,E701,E704
exclude =
.tox,
.git,
@@ -37,7 +39,7 @@ exclude =
.eggs,
data.
src/ptwt/__init__.py
max-line-length = 90
max-line-length = 80
max-complexity = 20
import-order-style = pycharm
application-import-names =
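The .flake8 changes above follow the configuration the black project recommends for flake8 compatibility: lower max-line-length, ignore E501, and select bugbear's B950 instead, which only flags lines that exceed the limit by roughly 10%. A minimal sketch of such a config (assuming the flake8-bugbear plugin is installed):

```ini
[flake8]
max-line-length = 80
; B950 (flake8-bugbear) replaces E501 and tolerates modest overflow;
; E701/E704 conflict with black's one-line def formatting.
extend-select = B950
extend-ignore = E501,E701,E704
```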
1 change: 0 additions & 1 deletion setup.cfg
@@ -51,7 +51,6 @@ install_requires =
torch
scipy>=1.10
pooch
matplotlib
numpy
pytest
nox
12 changes: 7 additions & 5 deletions src/ptwt/_stationary_transform.py
@@ -1,6 +1,7 @@
"""This module implements stationary wavelet transforms."""

from typing import List, Optional, Union
from collections.abc import Sequence
from typing import Optional, Union

import pywt
import torch
@@ -19,7 +20,7 @@ def _swt(
wavelet: Union[Wavelet, str],
level: Optional[int] = None,
axis: Optional[int] = -1,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
v0lta marked this conversation as resolved.
"""Compute a multilevel 1d stationary wavelet transform.

Args:
@@ -28,7 +29,7 @@ def _swt(
level (Optional[int], optional): The number of levels to compute

Returns:
List[torch.Tensor]: Same as wavedec.
list[torch.Tensor]: Same as wavedec.
Collaborator: as you go through these, it's better to remove the redundant type annotations inside the docstrings rather than update them

Suggested change:
- list[torch.Tensor]: Same as wavedec.
+ : Same as wavedec.

Collaborator Author: Should we keep the leading colons? The examples in Google's styleguide show none.

Collaborator: if the documentation works without it, then no! I am not usually using this code style so I don't know all the specifics

Collaborator Author: I went through the docstrings and cleaned them up. This should be addressed now.

The styleguide also writes about function arguments: "The description should include required type(s) if the code does not contain a corresponding type annotation." Opinions on this matter?

Equivalent to pywt.swt with trim_approx=True.

Raises:
@@ -107,14 +108,15 @@ def _conv_transpose_dedilate(


def _iswt(
coeffs: List[torch.Tensor],
coeffs: Sequence[torch.Tensor],
wavelet: Union[pywt.Wavelet, str],
axis: Optional[int] = -1,
) -> torch.Tensor:
"""Inverts a 1d stationary wavelet transform.

Args:
coeffs (List[torch.Tensor]): The coefficients as computed by the swt function.
coeffs (Sequence[torch.Tensor]): The coefficients as computed
by the swt function.
wavelet (Union[pywt.Wavelet, str]): The wavelet used for the forward transform.
axis (int, optional): The axis the forward trasform was computed over.
Defaults to -1.
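The signature changes above follow one consistent convention: parameters widen to the abstract collections.abc.Sequence, return types stay the concrete builtin list (PEP 585), and docstrings stop repeating types that the annotations already carry. A dependency-free sketch of that convention, using a hypothetical helper rather than actual ptwt code:

```python
from collections.abc import Sequence


def cumulative(values: Sequence[float]) -> list[float]:
    """Compute running sums of the input.

    Illustrates the convention adopted in this PR: accept any Sequence,
    return a concrete list, and let the annotations document the types
    instead of restating them in the docstring.

    Args:
        values: The input values.

    Returns:
        The running sums, one per input value.
    """
    total = 0.0
    result = []
    for value in values:
        total += value
        result.append(total)
    return result
```

Because the parameter is a Sequence, callers may pass a tuple or a list interchangeably, while the return type remains precise.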
95 changes: 66 additions & 29 deletions src/ptwt/_util.py
@@ -1,11 +1,13 @@
"""Utility methods to compute wavelet decompositions from a dataset."""

import typing
from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Callable, Optional, Protocol, Union, cast, overload

import numpy as np
import pywt
import torch
from typing_extensions import Unpack

from .constants import OrthogonalizeMethod

@@ -20,7 +22,7 @@ class Wavelet(Protocol):
rec_hi: Sequence[float]
dec_len: int
rec_len: int
filter_bank: Tuple[
filter_bank: tuple[
Sequence[float], Sequence[float], Sequence[float], Sequence[float]
]

@@ -29,6 +31,15 @@ def __len__(self) -> int:
return len(self.dec_lo)


WaveletDetailTuple2d = tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Collaborator: it's useful to add a docstring to each of these tuples to give some additional context about what they are and where they're used

Collaborator: also, this could be a place to use a namedtuple to make it even more explicit

Collaborator Author: Thats a good idea 👍
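The namedtuple suggestion was applied later in the PR (commit cd4c131, "Change detail coeff tuple to namedtuple"). A sketch of the idea; the field names are illustrative and float stand-ins replace torch.Tensor to keep the example dependency-free:

```python
from typing import NamedTuple


class WaveletDetailTuple2d(NamedTuple):
    """Detail coefficients of one 2d decomposition level.

    Hypothetical field names; the real ptwt class stores torch.Tensor
    values. A NamedTuple stays a plain tuple for unpacking, but each
    position also gets a self-documenting name and type.
    """

    horizontal: float
    vertical: float
    diagonal: float


coeffs = WaveletDetailTuple2d(horizontal=0.1, vertical=0.2, diagonal=0.3)
```

Existing tuple-based code keeps working unchanged, since `isinstance(coeffs, tuple)` still holds and positional unpacking behaves exactly as before.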

WaveletDetailDict = dict[str, torch.Tensor]

WaveletCoeffDetailTuple2d = tuple[
torch.Tensor, Unpack[tuple[WaveletDetailTuple2d, ...]]
]
WaveletCoeffDetailDict = tuple[torch.Tensor, Unpack[tuple[WaveletDetailDict, ...]]]
felixblanke marked this conversation as resolved.


def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
"""Ensure the input argument to be a pywt wavelet compatible object.

@@ -63,15 +74,15 @@ def _outer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a_mul * b_mul


def _get_len(wavelet: Union[Tuple[torch.Tensor, ...], str, Wavelet]) -> int:
def _get_len(wavelet: Union[tuple[torch.Tensor, ...], str, Wavelet]) -> int:
"""Get number of filter coefficients for various wavelet data types."""
if isinstance(wavelet, tuple):
return wavelet[0].shape[0]
else:
return len(_as_wavelet(wavelet))


def _pad_symmetric_1d(signal: torch.Tensor, pad_list: Tuple[int, int]) -> torch.Tensor:
def _pad_symmetric_1d(signal: torch.Tensor, pad_list: tuple[int, int]) -> torch.Tensor:
padl, padr = pad_list
dimlen = signal.shape[0]
if padl > dimlen or padr > dimlen:
@@ -88,11 +99,11 @@ def _pad_symmetric_1d(signal: torch.Tensor, pad_list: Tuple[int, int]) -> torch.
cat_list.insert(0, signal[:padl].flip(0))
if padr > 0:
cat_list.append(signal[-padr::].flip(0))
return torch.cat(cat_list, axis=0) # type: ignore
return torch.cat(cat_list, dim=0)


def _pad_symmetric(
signal: torch.Tensor, pad_lists: List[Tuple[int, int]]
signal: torch.Tensor, pad_lists: Sequence[tuple[int, int]]
) -> torch.Tensor:
if len(signal.shape) < len(pad_lists):
raise ValueError("not enough dimensions to pad.")
@@ -106,15 +117,15 @@ def _pad_symmetric(
return signal


def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int]]:
def _fold_axes(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int]]:
"""Fold unchanged leading dimensions into a single batch dimension.

Args:
data ( torch.Tensor): The input data array.
data (torch.Tensor): The input data array.
keep_no (int): The number of dimensions to keep.

Returns:
Tuple[ torch.Tensor, List[int]]:
tuple[torch.Tensor, list[int]]:
The folded result array, and the shape of the original input.
"""
dshape = list(data.shape)
@@ -124,7 +135,7 @@ def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int
)


def _unfold_axes(data: torch.Tensor, ds: List[int], keep_no: int) -> torch.Tensor:
def _unfold_axes(data: torch.Tensor, ds: list[int], keep_no: int) -> torch.Tensor:
"""Unfold i.e. [batch*channel,height,widht] to [batch,channel,height,width]."""
return torch.reshape(data, ds[:-keep_no] + list(data.shape[-keep_no:]))

@@ -137,49 +148,75 @@ def _check_if_tensor(array: Any) -> torch.Tensor:
return array


def _check_axes_argument(axes: List[int]) -> None:
def _check_axes_argument(axes: Sequence[int]) -> None:
if len(set(axes)) != len(axes):
raise ValueError("Cant transform the same axis twice.")


def _get_transpose_order(
axes: List[int], data_shape: List[int]
) -> Tuple[List[int], List[int]]:
axes: Sequence[int], data_shape: Sequence[int]
) -> tuple[list[int], list[int]]:
axes = list(map(lambda a: a + len(data_shape) if a < 0 else a, axes))
all_axes = list(range(len(data_shape)))
remove_transformed = list(filter(lambda a: a not in axes, all_axes))
return remove_transformed, axes


def _swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor:
def _swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor:
_check_axes_argument(axes)
front, back = _get_transpose_order(axes, list(data.shape))
return torch.permute(data, front + back)


def _undo_swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor:
def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor:
_check_axes_argument(axes)
front, back = _get_transpose_order(axes, list(data.shape))
restore_sorted = torch.argsort(torch.tensor(front + back)).tolist()
return torch.permute(data, restore_sorted)


@overload
def _map_result(
data: WaveletCoeffDetailTuple2d,
function: Callable[[torch.Tensor], torch.Tensor],
) -> WaveletCoeffDetailTuple2d: ...


@overload
def _map_result(
data: List[Union[torch.Tensor, Any]], # following jax tree_map typing can be Any
function: Callable[[Any], torch.Tensor],
) -> List[Union[torch.Tensor, Any]]:
# Apply the given function to the input list of tensor and tuples.
result_lst: List[Union[torch.Tensor, Any]] = []
for element in data:
if isinstance(element, torch.Tensor):
result_lst.append(function(element))
elif isinstance(element, tuple):
data: WaveletCoeffDetailDict,
function: Callable[[torch.Tensor], torch.Tensor],
) -> WaveletCoeffDetailDict: ...


def _map_result(
data: Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict],
function: Callable[[torch.Tensor], torch.Tensor],
) -> Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict]:
approx = function(data[0])
result_lst: list[
Union[
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
dict[str, torch.Tensor],
]
] = []
for element in data[1:]:
if isinstance(element, tuple):
result_lst.append(
(function(element[0]), function(element[1]), function(element[2]))
(
function(element[0]),
function(element[1]),
function(element[2]),
)
)
elif isinstance(element, dict):
new_dict = {}
for key, value in element.items():
new_dict[key] = function(value)
new_dict = {key: function(value) for key, value in element.items()}
result_lst.append(new_dict)
return result_lst
else:
raise AssertionError(f"Unexpected input type {type(element)}")
felixblanke marked this conversation as resolved.

return_val = approx, *result_lst
return_val = cast(
Union[WaveletCoeffDetailTuple2d, WaveletCoeffDetailDict], return_val
)
return return_val
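The two overloads above give callers a precise return type: a tuple-based coefficient structure maps to the same tuple-based alias, and a dict-based one to the dict-based alias, even though a single implementation handles both. A simplified, torch-free sketch of the pattern, with float stand-ins and hypothetical names:

```python
from typing import Callable, Union, overload

TupleCoeffs = tuple[float, ...]
DictCoeffs = dict[str, float]


@overload
def map_coeffs(
    data: TupleCoeffs, function: Callable[[float], float]
) -> TupleCoeffs: ...


@overload
def map_coeffs(
    data: DictCoeffs, function: Callable[[float], float]
) -> DictCoeffs: ...


def map_coeffs(
    data: Union[TupleCoeffs, DictCoeffs],
    function: Callable[[float], float],
) -> Union[TupleCoeffs, DictCoeffs]:
    # One implementation serves both container kinds; the @overload
    # stubs above are what let a checker keep the input type precise.
    if isinstance(data, tuple):
        return tuple(function(value) for value in data)
    return {key: function(value) for key, value in data.items()}
```

Without the overloads, every caller would see the Union return type and need an isinstance check or cast; with them, the static type flows through unchanged.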
8 changes: 4 additions & 4 deletions src/ptwt/continuous_transform.py
@@ -3,7 +3,7 @@
This module is based on pywt's cwt implementation.
"""

from typing import Any, Tuple, Union
from typing import Any, Union

import numpy as np
import torch
@@ -27,7 +27,7 @@ def cwt(
scales: Union[np.ndarray, torch.Tensor], # type: ignore
wavelet: Union[ContinuousWavelet, str],
sampling_period: float = 1.0,
) -> Tuple[torch.Tensor, np.ndarray]: # type: ignore
) -> tuple[torch.Tensor, np.ndarray]: # type: ignore
"""Compute the single-dimensional continuous wavelet transform.

This function is a PyTorch port of pywt.cwt as found at:
@@ -50,7 +50,7 @@ def cwt(
ValueError: If a scale is too small for the input signal.

Returns:
Tuple[torch.Tensor, np.ndarray]: The first tuple-element contains
tuple[torch.Tensor, np.ndarray]: The first tuple-element contains
the transformation matrix of shape [scales, batch, time].
The second element contains an array with frequency information.

@@ -267,7 +267,7 @@ def center(self) -> torch.Tensor:

def wavefun(
self, precision: int, dtype: torch.dtype = torch.float64
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Define a grid and evaluate the wavelet on it."""
length = 2**precision
# load the bounds from untyped pywt code.