Merge pull request #87 from v0lta/improve-typing
Improve typing and docstrings
v0lta authored Jun 18, 2024
2 parents 78abc5f + c973350 commit c256d9e
Showing 36 changed files with 900 additions and 882 deletions.
4 changes: 3 additions & 1 deletion .flake8
@@ -23,6 +23,8 @@ ignore =
# asserts are ok in test.
S101
C901
extend-select = B950
extend-ignore = E501,E701,E704
exclude =
.tox,
.git,
@@ -37,7 +39,7 @@ exclude =
.eggs,
data.
src/ptwt/__init__.py
max-line-length = 90
max-line-length = 80
max-complexity = 20
import-order-style = pycharm
application-import-names =
9 changes: 7 additions & 2 deletions docs/conf.py
@@ -70,8 +70,6 @@
html_favicon = "_static/favicon.ico"
html_logo = "_static/shannon.png"

html_favicon = "favicon/favicon.ico"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
@@ -82,3 +80,10 @@

# numbered figures
numfig = True

autodoc_type_aliases = {
"WaveletCoeff2d": "ptwt.constants.WaveletCoeff2d",
"WaveletCoeff2dSeparable": "ptwt.constants.WaveletCoeff2dSeparable",
"WaveletCoeffNd": "ptwt.constants.WaveletCoeffNd",
"BaseMatrixWaveDec": "ptwt.matmul_transform.BaseMatrixWaveDec",
}
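
A note on the new autodoc_type_aliases block: it keeps the alias names in rendered signatures instead of expanding them into the underlying tuple and Union definitions. A hypothetical module-side sketch of the effect (Sphinx applies the mapping to annotations evaluated as strings, e.g. under from __future__ import annotations):

    # Illustration only, not part of this diff.
    from __future__ import annotations

    from ptwt.constants import WaveletCoeff2d


    def soft_threshold(coeffs: WaveletCoeff2d, value: float) -> WaveletCoeff2d:
        """Shrink detail coefficients toward zero (sketch only)."""
        ...

    # With "WaveletCoeff2d": "ptwt.constants.WaveletCoeff2d" registered in
    # autodoc_type_aliases, autodoc renders the annotations above as
    # WaveletCoeff2d rather than the expanded alias definition.
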
20 changes: 12 additions & 8 deletions docs/ptwt.rst
@@ -32,6 +32,7 @@ ptwt.packets module

.. automodule:: ptwt.packets
:members:
:special-members: __getitem__
:undoc-members:
:show-inheritance:

@@ -68,6 +69,7 @@ ptwt.matmul\_transform module

.. automodule:: ptwt.matmul_transform
:members:
:special-members: __call__
:undoc-members:
:show-inheritance:

@@ -76,6 +78,7 @@ ptwt.matmul\_transform\_2 module

.. automodule:: ptwt.matmul_transform_2
:members:
:special-members: __call__
:undoc-members:
:show-inheritance:

@@ -84,6 +87,7 @@ ptwt.matmul\_transform\_3 module

.. automodule:: ptwt.matmul_transform_3
:members:
:special-members: __call__
:undoc-members:
:show-inheritance:

@@ -96,14 +100,6 @@ ptwt.sparse\_math module
:undoc-members:
:show-inheritance:

ptwt.version module
-------------------

.. automodule:: ptwt.version
:members:
:undoc-members:
:show-inheritance:

ptwt.wavelets\_learnable module
-------------------------------

@@ -118,3 +114,11 @@ ptwt.constants
:members:
:undoc-members:
:show-inheritance:

ptwt.version module
-------------------

.. automodule:: ptwt.version
:members:
:undoc-members:
:show-inheritance:
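
The added :special-members: entries matter because these classes are used through dunder methods: the matmul transforms are callables and the packet trees are indexed. A rough usage sketch, assuming the MatrixWavedec and WaveletPacket2D interfaces exported by ptwt:

    # Illustration only, not part of this diff.
    import pywt
    import torch

    import ptwt

    # MatrixWavedec's entry point is __call__, now documented via :special-members:.
    matrix_wavedec = ptwt.MatrixWavedec(pywt.Wavelet("haar"), level=3)
    coeffs = matrix_wavedec(torch.randn(8, 256))

    # WaveletPacket2D is read through __getitem__ with a/h/v/d path strings.
    packets = ptwt.WaveletPacket2D(torch.randn(1, 64, 64), pywt.Wavelet("haar"))
    low_low = packets["aa"]
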
47 changes: 4 additions & 43 deletions examples/deepfake_analysis/packet_plot.py
@@ -11,47 +11,6 @@
import ptwt


def get_freq_order(level: int):
"""Get the frequency order for a given packet decomposition level.
Adapted from:
https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py
The code elements denote the filter application order. The filters
are named following the pywt convention as:
a - LL, low-low coefficients
h - LH, low-high coefficients
v - HL, high-low coefficients
d - HH, high-high coefficients
"""
wp_natural_path = list(product(["a", "h", "v", "d"], repeat=level))

def _get_graycode_order(level, x="a", y="d"):
graycode_order = [x, y]
for _ in range(level - 1):
graycode_order = [x + path for path in graycode_order] + [
y + path for path in graycode_order[::-1]
]
return graycode_order

def _expand_2d_path(path):
expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"}
return (
"".join([expanded_paths[p][0] for p in path]),
"".join([expanded_paths[p][1] for p in path]),
)

nodes: dict = {}
for (row_path, col_path), node in [
(_expand_2d_path(node), node) for node in wp_natural_path
]:
nodes.setdefault(row_path, {})[col_path] = node
graycode_order = _get_graycode_order(level, x="l", y="h")
nodes_list: list = [nodes[path] for path in graycode_order if path in nodes]
wp_frequency_path = []
for row in nodes_list:
wp_frequency_path.append([row[path] for path in graycode_order if path in row])
return wp_frequency_path, wp_natural_path


def generate_frequency_packet_image(packet_array: np.ndarray, degree: int):
"""Create a ready-to-polt image with frequency-order packages.
Given a packet array in natural order, creat an image which is
@@ -63,7 +22,8 @@ def generate_frequency_packet_image(packet_array: np.ndarray, degree: int):
Returns:
[np.ndarray]: The image of shape [original_height, original_width]
"""
wp_freq_path, wp_natural_path = get_freq_order(degree)
wp_freq_path = ptwt.WaveletPacket2D.get_freq_order(degree)
wp_natural_path = ptwt.WaveletPacket2D.get_natural_order(degree)

image = []
# go through the rows.
@@ -107,7 +67,8 @@ def load_images(path: str) -> list:


if __name__ == "__main__":
frequency_path, natural_path = get_freq_order(level=3)
frequency_path = ptwt.WaveletPacket2D.get_freq_order(level=3)
natural_path = ptwt.WaveletPacket2D.get_natural_order(level=3)
print("Loading ffhq images:")
ffhq_images = load_images("./ffhq_style_gan/source_data/A_ffhq")
print("processing ffhq")
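
The local get_freq_order helper is replaced by classmethods on ptwt.WaveletPacket2D. A sketch of using them to build a frequency-ordered packet image, assuming the classmethods return the same nested path structure as the removed helper:

    # Illustration only, not part of this diff.
    import pywt
    import torch

    import ptwt

    level = 3
    natural_paths = ptwt.WaveletPacket2D.get_natural_order(level)  # a/h/v/d paths
    freq_rows = ptwt.WaveletPacket2D.get_freq_order(level)         # grouped per image row

    packets = ptwt.WaveletPacket2D(torch.randn(1, 256, 256), pywt.Wavelet("haar"))
    rows = []
    for row in freq_rows:
        # "".join works whether a node is a path string or a tuple of characters.
        rows.append(torch.cat([packets["".join(node)] for node in row], dim=-1))
    packet_image = torch.cat(rows, dim=-2)
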
24 changes: 3 additions & 21 deletions examples/speed_tests/timeitconv_1d.py
@@ -9,26 +9,8 @@
import ptwt


class WaveletTuple(NamedTuple):
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""

dec_lo: torch.Tensor
dec_hi: torch.Tensor
rec_lo: torch.Tensor
rec_hi: torch.Tensor


def _set_up_wavelet_tuple(wavelet, dtype):
return WaveletTuple(
torch.tensor(wavelet.dec_lo).type(dtype),
torch.tensor(wavelet.dec_hi).type(dtype),
torch.tensor(wavelet.rec_lo).type(dtype),
torch.tensor(wavelet.rec_hi).type(dtype),
)


def _jit_wavedec_fun(data, wavelet):
return ptwt.wavedec(data, wavelet, "periodic", level=10)
return ptwt.wavedec(data, wavelet, mode="periodic", level=10)


if __name__ == "__main__":
@@ -56,7 +38,7 @@ def _jit_wavedec_fun(data, wavelet):
end = time.perf_counter()
ptwt_time_cpu.append(end - start)

wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
jit_wavedec = torch.jit.trace(
_jit_wavedec_fun,
(data, wavelet),
@@ -81,7 +63,7 @@ def _jit_wavedec_fun(data, wavelet):
end = time.perf_counter()
ptwt_time_gpu.append(end - start)

wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
jit_wavedec = torch.jit.trace(
_jit_wavedec_fun,
(data.cuda(), wavelet),
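
The file-local WaveletTuple and _set_up_wavelet_tuple boilerplate gives way to the newly exported ptwt.WaveletTensorTuple. A condensed sketch of the tracing pattern used in these speed tests, assuming from_wavelet packs a pywt wavelet's dec/rec filters into tensors of the given dtype:

    # Illustration only, not part of this diff.
    import pywt
    import torch

    import ptwt


    def _jit_wavedec_fun(data, wavelet):
        return ptwt.wavedec(data, wavelet, mode="periodic", level=10)


    wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
    data = torch.randn(32, 2**15, dtype=torch.float32)

    # Trace once with tensor-valued filters, then reuse the traced function.
    jit_wavedec = torch.jit.trace(_jit_wavedec_fun, (data, wavelet), strict=False)
    coeffs = jit_wavedec(data, wavelet)
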
24 changes: 3 additions & 21 deletions examples/speed_tests/timeitconv_2d.py
@@ -9,27 +9,9 @@
import ptwt


class WaveletTuple(NamedTuple):
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""

dec_lo: torch.Tensor
dec_hi: torch.Tensor
rec_lo: torch.Tensor
rec_hi: torch.Tensor


def _set_up_wavelet_tuple(wavelet, dtype):
return WaveletTuple(
torch.tensor(wavelet.dec_lo).type(dtype),
torch.tensor(wavelet.dec_hi).type(dtype),
torch.tensor(wavelet.rec_lo).type(dtype),
torch.tensor(wavelet.rec_hi).type(dtype),
)


def _to_jit_wavedec_2(data, wavelet):
def _to_jit_wavedec_2(data: torch.Tensor, wavelet) -> list[torch.Tensor]:
"""Ensure uniform datatypes in lists for the tracer.
Going from List[Union[torch.Tensor, List[torch.Tensor]]] to List[torch.Tensor]
Going from list[Union[torch.Tensor, list[torch.Tensor]]] to list[torch.Tensor]
means we have to stack the lists in the output.
"""
assert data.shape == (32, 1e3, 1e3), "Changing the shape requires re-tracing."
@@ -79,7 +61,7 @@ def _to_jit_wavedec_2(data, wavelet):

ptwt_time_gpu.append(end - start)

wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
jit_wavedec = torch.jit.trace(
_to_jit_wavedec_2,
(data.cuda(), wavelet),
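
_to_jit_wavedec_2 stacks each level's detail tuple so the tracer only sees a uniform list[torch.Tensor]. A small sketch of undoing that stacking to recover the usual wavedec2 layout, for instance before reconstructing with ptwt.waverec2:

    # Illustration only, not part of this diff.
    import torch


    def _unstack_coeffs(flat: list[torch.Tensor]):
        """Invert the per-level torch.stack applied by _to_jit_wavedec_2."""
        coeffs = [flat[0]]  # approximation coefficients stay as they are
        for stacked in flat[1:]:
            coeffs.append(tuple(stacked.unbind(0)))  # back to a 3-tuple of details
        return coeffs
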
22 changes: 2 additions & 20 deletions examples/speed_tests/timeitconv_2d_separable.py
@@ -10,31 +10,13 @@
import ptwt


class WaveletTuple(NamedTuple):
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""

dec_lo: torch.Tensor
dec_hi: torch.Tensor
rec_lo: torch.Tensor
rec_hi: torch.Tensor


def _set_up_wavelet_tuple(wavelet, dtype):
return WaveletTuple(
torch.tensor(wavelet.dec_lo).type(dtype),
torch.tensor(wavelet.dec_hi).type(dtype),
torch.tensor(wavelet.rec_lo).type(dtype),
torch.tensor(wavelet.rec_hi).type(dtype),
)


def _to_jit_wavedec_2(data, wavelet):
"""Ensure uniform datatypes in lists for the tracer.
Going from List[Union[torch.Tensor, List[torch.Tensor]]] to List[torch.Tensor]
means we have to stack the lists in the output.
"""
assert data.shape == (32, 1e3, 1e3), "Changing the shape requires re-tracing."
coeff = ptwt.fswavedec2(data, wavelet, "reflect", level=5)
coeff = ptwt.fswavedec2(data, wavelet, mode="reflect", level=5)
coeff2 = []
for c in coeff:
if isinstance(c, torch.Tensor):
@@ -103,7 +85,7 @@ def _to_jit_wavedec_2(data, wavelet):
end = time.perf_counter()
ptwt_time_gpu.append(end - start)

wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
jit_wavedec = torch.jit.trace(
_to_jit_wavedec_2,
(data.cuda(), wavelet),
20 changes: 1 addition & 19 deletions examples/speed_tests/timeitconv_3d.py
@@ -9,24 +9,6 @@
import ptwt


class WaveletTuple(NamedTuple):
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""

dec_lo: torch.Tensor
dec_hi: torch.Tensor
rec_lo: torch.Tensor
rec_hi: torch.Tensor


def _set_up_wavelet_tuple(wavelet, dtype):
return WaveletTuple(
torch.tensor(wavelet.dec_lo).type(dtype),
torch.tensor(wavelet.dec_hi).type(dtype),
torch.tensor(wavelet.rec_lo).type(dtype),
torch.tensor(wavelet.rec_hi).type(dtype),
)


def _to_jit_wavedec_3(data, wavelet):
"""Ensure uniform datatypes in lists for the tracer.
@@ -85,7 +67,7 @@ def _to_jit_wavedec_3(data, wavelet):
end = time.perf_counter()
ptwt_time_gpu.append(end - start)

wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
jit_wavedec = torch.jit.trace(
_to_jit_wavedec_3,
(data.cuda(), wavelet),
2 changes: 1 addition & 1 deletion setup.cfg
@@ -67,7 +67,7 @@ tests =
# pooch is an optional scipy dependency for getting datasets
pooch
typing =
mypy
mypy @ git+https://github.com/python/mypy
# needed otherwise pytest decorators don't get typed properly
pytest
examples =
3 changes: 2 additions & 1 deletion src/ptwt/__init__.py
@@ -1,6 +1,7 @@
"""Differentiable and gpu enabled fast wavelet transforms in PyTorch."""

from ._util import Wavelet
from ._util import Wavelet, WaveletTensorTuple
from .constants import WaveletCoeff2d, WaveletCoeff2dSeparable, WaveletCoeffNd
from .continuous_transform import cwt
from .conv_transform import wavedec, waverec
from .conv_transform_2 import wavedec2, waverec2
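
With this re-export, the filter tuple and the coefficient aliases are importable straight from the package namespace. A short hypothetical usage sketch:

    # Illustration only, not part of this diff.
    import pywt
    import torch

    from ptwt import WaveletCoeff2d, WaveletTensorTuple, wavedec2

    wavelet = WaveletTensorTuple.from_wavelet(pywt.Wavelet("db3"), torch.float32)
    coeffs: WaveletCoeff2d = wavedec2(torch.randn(1, 128, 128), wavelet, level=2)
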