Make preprocessing and postprocessing consistent across transforms #93

Merged
merged 36 commits into main from fix/keep-ndims-Nd on Jul 1, 2024

Conversation

felixblanke
Collaborator

This addresses #92.

For all discrete transforms, the preprocessing and postprocessing of coefficients and tensors are very similar (e.g. folding and swapping of axes, adding batch dims). This PR moves this functionality into shared functions that use _map_result.

Also, the check for consistent devices and dtypes across coefficient tensors is moved into a function in _utils.

Lastly, as it was possible to add them with a few lines of code, I added the $n$-dimensional fully separable transforms (fswavedecn, fswaverecn). If this is not wanted, I can revert their addition.

Further, I did some minor refactorings along the way.
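
A rough sketch of the shared preprocessing idea (the helper names below are only illustrative, not the final ptwt API):

import torch

def _check_same_device_dtype(tensors: list[torch.Tensor]) -> None:
    # illustrative sketch: all coefficient tensors must share one device and dtype
    ref = tensors[0]
    if any(t.device != ref.device or t.dtype != ref.dtype for t in tensors[1:]):
        raise ValueError("coefficients must share a common device and dtype")

def _preprocess_tensor(data: torch.Tensor, ndim: int, axes: tuple[int, ...]) -> torch.Tensor:
    # illustrative sketch: move the transformed axes to the back and add a batch dim
    if len(axes) != ndim:
        raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
    data = torch.movedim(data, axes, tuple(range(-ndim, 0)))
    if data.dim() == ndim:
        data = data.unsqueeze(0)
    return data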

@felixblanke added the enhancement and invalid labels on Jun 24, 2024
@v0lta
Owner

v0lta commented Jun 25, 2024

I did not add n-dimensional separable transforms on purpose, because I expected people would then ask for these in all other cases too, where they are trickier to deliver.

Review thread on src/ptwt/_util.py (outdated):
    raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
else:
    undo_swap_fn = partial(_undo_swap_axes, axes=axes)
    coeffs = _map_result(coeffs, undo_swap_fn)
Collaborator

I would really advise against all of these _map_result calls - have one function that does the processing that can be reused, then just do list comprehensions for all successive function calls.

coeffs = _map_result(coeffs, lambda x: x.squeeze(0))

will always be less readable and understandable than

coeffs = _map_result(coeffs)
coeffs = [coeff.squeeze(0) for coeff in coeffs]

when _map_result has lots of hidden functionality

Collaborator Author

The snippet

coeffs = _map_result(coeffs, lambda x: x.squeeze(0))

applies the function x.squeeze(0) to all tensors in coeffs. In the 1d case (where coeffs is of type list[Tensor]) this is equivalent to the list comprehension, as you said. However, coeffs might also be

  • (Tensor, dict[str, Tensor], ...)
  • (Tensor, (Tensor, Tensor, Tensor), ...)

So using _map_result allows us to write the function once for all possible coefficient types. Would it perhaps help to rename _map_result or add documentation for it?
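
For illustration, a stripped-down version of the idea (a sketch only, not the actual _map_result implementation):

import torch

def _map_tensors(coeffs, fn):
    # sketch: apply fn to every tensor leaf while keeping the nesting structure
    if isinstance(coeffs, torch.Tensor):
        return fn(coeffs)
    if isinstance(coeffs, dict):
        return {key: _map_tensors(val, fn) for key, val in coeffs.items()}
    if isinstance(coeffs, (list, tuple)):
        return type(coeffs)(_map_tensors(elem, fn) for elem in coeffs)
    return coeffs

# works the same for list[Tensor], (Tensor, dict[str, Tensor], ...),
# and (Tensor, (Tensor, Tensor, Tensor), ...):
# coeffs = _map_tensors(coeffs, lambda x: x.squeeze(0))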

Owner

I think the name should have something to do with tree and map. The new _apply_to_tensor_elems name really hides how general the concept is. I would take a page from https://jax.readthedocs.io/en/latest/_autosummary/jax.tree.map.html#jax.tree.map and also use their type hinting. The concept does not exist in torch, but I think it makes sense here, since we save on a lot of boilerplate code. Perhaps we should include the link and explain what's going on?
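
For reference, jax.tree.map applies a function to every leaf of a nested container while keeping the structure intact, for example:

import jax

nested = {"a": (1.0, 2.0), "b": [3.0]}
jax.tree.map(lambda x: 2 * x, nested)
# {'a': (2.0, 4.0), 'b': [6.0]}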

Owner

Here is an interesting intro discussing the pytree processing philosophy: https://jax.readthedocs.io/en/latest/working-with-pytrees.html .

Owner

I also think @cthoyt has a point since the tree-map concept is not very popular.

Owner

I agree when comparing

coeffs = _map_result(coeffs, partial(torch.squeeze, dim=0))

to

coeffs = [coeff.squeeze(0) for coeff in coeffs]

The list comprehension wins, but what if it's a nested structure?

@felixblanke
Collaborator Author

@v0lta I made the n-dim transform private. Does that work?

@v0lta
Owner

v0lta commented Jun 26, 2024

Yes, that works. However, we won't be able to support n-dimensional transforms across the board, because PyTorch does not provide the interfaces we would need to do that. Padding, for example, only works up to 3D ( https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html ). We have the same problem with isotropic convolution. So I think we should communicate that nd-transforms are out of the scope of this project.
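
A quick illustration of the limit (based on current PyTorch behaviour; the exact error may vary between versions):

import torch
import torch.nn.functional as F

x3d = torch.randn(1, 1, 8, 8, 8)  # batch, channel, and three signal dimensions
F.pad(x3d, (2, 2, 2, 2, 2, 2), mode="reflect")  # supported: reflect padding up to 3D

x4d = torch.randn(1, 1, 8, 8, 8, 8)  # four signal dimensions
# F.pad(x4d, (2, 2) * 4, mode="reflect")  # not supported, raises an error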

@v0lta
Owner

v0lta commented Jun 26, 2024

In general, I am a big fan of this pull request! Thanks @felixblanke! I am going to clean up the docs for _map_result and commit here.

@v0lta
Owner

v0lta commented Jun 26, 2024

Our coeff_tree_map is not a general tree map, but it does not have to be, since we know the approximation tensor will always be the first entry. I ran the tests that are not marked slow; everything checked out. The code is cleaner now. If everyone is on board, I would be ready to merge.

@felixblanke
Collaborator Author

I think we so far only refer to the Packet data structure as a tree. Maybe we can add a link to the JAX discussion as a reference?

@v0lta
Owner

v0lta commented Jun 26, 2024

I am not sure if users need to know. I think this is more for us here internally. Unlike JAX's tree map, ours is coefficient-specific, hence the proposed coeff_tree_map function name.

@v0lta
Owner

v0lta commented Jun 26, 2024

Actually, never mind. The user argument does not matter since it's a private function. If you think it helps to understand the idea, please add a link. I think it might help potential future contributors.

@v0lta self-assigned this on Jul 1, 2024
@v0lta
Owner

v0lta commented Jul 1, 2024

Okay, I think we are ready to merge. @felixblanke @cthoyt is everyone on board?

@v0lta
Owner

v0lta commented Jul 1, 2024

okay let's merge!

@v0lta merged commit 85b898a into main on Jul 1, 2024
7 checks passed
@v0lta deleted the fix/keep-ndims-Nd branch on July 1, 2024, 09:27