Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SWT: Make circular padding wrap more than once if needed #84

Merged
merged 17 commits into from
Jun 13, 2024

Conversation

NiclasPi
Copy link
Collaborator

Some combinations of wavelets and number of decomposition steps cause the following RuntimeError: "Padding value causes wrapping around more than once".
If the padding values are greater than the input length, torch.nn.functional.pad does not wrap around more than once in circluar mode. We need to manually wrap around more than once, if needed. See feature request https://github.com/pytorch/pytorch/issues/57911 in PyTorch.

@v0lta v0lta added bug Something isn't working enhancement New feature or request labels Apr 30, 2024
@NiclasPi
Copy link
Collaborator Author

NiclasPi commented Jun 3, 2024

Thank you for the review @cthoyt ! I added explainig comments and some tests comparing the function with numpys wrap padding.

@v0lta
Copy link
Owner

v0lta commented Jun 6, 2024

I see failed tests on my machine:
tests/test_swt.py .....FFF.........FFF.........FFF.........FFF................................................................... [ 98%] ........

test_swt_1d[db1-size0-1], test_swt_1d[db1-size0-2], test_swt_1d[db1-size0-None], test_swt_1d[db2-size0-1], test_swt_1d[db2-size0-2], test_swt_1d[db2-size0-None] and others dont pass.
I will look into this tomorrow. Perhaps I can figure out whats going on. @NiclasPi does everything pass on your machine?

@v0lta
Copy link
Owner

v0lta commented Jun 6, 2024

FAILED tests/test_swt.py::test_swt_1d[db1-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db1-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db1-size0-None] - assert False
FAILED tests/test_swt.py::test_swt_1d[db2-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db2-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db2-size0-None] - assert False
FAILED tests/test_swt.py::test_swt_1d[db3-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db3-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db3-size0-None] - assert False
FAILED tests/test_swt.py::test_swt_1d[db4-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db4-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db4-size0-None] - assert False

a partial log from nox -s test

@cthoyt
Copy link
Collaborator

cthoyt commented Jun 6, 2024

@v0lta we need to update this line

to make tests run on external PRs, it should look like this https://github.com/biopragmatics/curies/blob/470f71a69264c17260823d485de113695076a38f/.github/workflows/tests.yml#L3-L7

@v0lta
Copy link
Owner

v0lta commented Jun 7, 2024

Done in the main branch.

@cthoyt
Copy link
Collaborator

cthoyt commented Jun 7, 2024

@v0lta @NiclasPi the tests are now showing up properly in the PR

@v0lta
Copy link
Owner

v0lta commented Jun 7, 2024

TODO: Port shape fold-unfolding code to swt, properly.

@NiclasPi
Copy link
Collaborator Author

Sorry, I forgot to mention that I added test cases that weren't covered yet. Specifically, since I am working with 1-dimensional signals, I added cases with shape = (N,). The code base currently only tests 1-dimensional signals with shape = (1, N).
After closer inspection it turned out to be a problem in the folding-unfolding helper functions. These should also work with shape = (N,). @v0lta is going to look into this issue and will fix it. Afterwards, the failing tests for the SWT should pass.

@v0lta
Copy link
Owner

v0lta commented Jun 10, 2024

Thanks @NiclasPi . I have a fix in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/fix-padding which starts to address this. We currently get (1, T) outputs shapes for (N,) inputs across the entire toolbox. With N, the input measurements and T for the dimension at some transform level. This is not what we want since the Toolbox should always respect the uses choice of input dimension. I will fix this across the board.
@cthoyt how did you [Merge branch 'main' into pr/84]. My changes ended up in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/fix-padding , unfortunately. I think I should know this but here we are ;-).

@v0lta
Copy link
Owner

v0lta commented Jun 10, 2024

It's not a catastrophic problem users can currently fix this by running out.squeeze() themselves, but since we want to be pywt-compatible the shapes should be identical, so we'll fix this.

@cthoyt
Copy link
Collaborator

cthoyt commented Jun 10, 2024

you need to add the remote (see https://docs.github.com/en/get-started/getting-started-with-git/managing-remote-repositories)

git remote add NiclasPi https://github.com/NiclasPi/PyTorch-Wavelet-Toolbox.git
git fetch --all

then you can switch to the right fix-padding branch

@v0lta
Copy link
Owner

v0lta commented Jun 11, 2024

Thanks @cthoyt ! I think we are almost there now. The last thing I would like to do here ist get rid of _conv_transpose_dedilate and use group convolution instead.

@v0lta
Copy link
Owner

v0lta commented Jun 12, 2024

Hi Team, I removed the _conv_transpose_dedilate; the code is simpler and faster now. I also made the SWT module public by removing the underscore. From my point of view, this PR is ready for a merge.

@v0lta v0lta added this to the v0.1.9 milestone Jun 12, 2024
@v0lta v0lta merged commit 78abc5f into v0lta:main Jun 13, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants