Adds spectral functions to Mesh TensorFlow #250
base: master
Conversation
Added complex manipulation ops and unit tests for added ops
An important note: in addition to the spectral ops, this PR adds complex support for the gradient (otherwise gradients would not flow through the spectral ops) and complex manipulation operations for navigating between complex-valued and float-valued tensors.
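For illustration only, here is one generic way to move between float-valued and complex-valued tensors in Mesh TensorFlow using the slicewise wrapper; the dedicated ops added by this PR may look different, so treat the names below as a sketch rather than the PR's API:

```python
# Sketch (not the PR's actual ops): pack two float tensors into a complex
# tensor and recover the real part, using Mesh TensorFlow's generic cwise
# wrapper, which applies a TensorFlow function slice-by-slice on the mesh.
import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

def to_complex(real_part, imag_part):
  # real_part, imag_part: float32 mtf.Tensor with identical shapes.
  return mtf.cwise(tf.complex, [real_part, imag_part],
                   output_dtype=tf.complex64)

def to_real(complex_tensor):
  # complex_tensor: complex64 mtf.Tensor.
  return mtf.cwise(tf.math.real, [complex_tensor], output_dtype=tf.float32)
```

Gradient flow through complex-valued tensors like these is what the gradient changes in this PR are meant to enable.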
This PR adds to Mesh TensorFlow the spectral operations needed for the flowpm project, which was the subject of this blog post: https://blog.tensorflow.org/2020/03/simulating-universe-in-tensorflow.html .
These operations are useful for many applications, including N-body simulations and MRI reconstruction. For now, only 3D FFTs are implemented.
The implementation applies a series of 1D FFTs along the trailing dimension of the input tensor, then uses all2all communications and local transpose operations to bring each remaining dimension into the trailing position, until all 3 dimensions have been transformed.
The algorithm is illustrated here:
[figure: transpose-based parallel FFT decomposition, from https://www-user.tu-chemnitz.de/~potts/workgroup/pippig/paper/PFFT_SIAM_88588.pdf]
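To make the structure concrete, here is a single-process NumPy sketch of the same decomposition (no distribution involved; each all2all plus local transpose is collapsed into a plain transpose):

```python
# NumPy sketch of the transpose-based 3D FFT: three batches of 1D FFTs along
# the trailing axis, interleaved with transposes that rotate the next axis
# into the trailing position. In the distributed setting each transpose is
# realized with an all2all communication plus a local transpose.
import numpy as np

x = np.random.randn(8, 8, 8) + 1j * np.random.randn(8, 8, 8)

y = np.fft.fft(x, axis=-1)        # transform the trailing axis
y = np.transpose(y, (1, 2, 0))    # rotate axes: next axis becomes trailing
y = np.fft.fft(y, axis=-1)
y = np.transpose(y, (1, 2, 0))
y = np.fft.fft(y, axis=-1)        # all three axes are now transformed

# The result matches a full 3D FFT, but with the axes permuted relative to
# the input (one of the caveats noted below).
assert np.allclose(y, np.transpose(np.fft.fftn(x), (2, 0, 1)))
```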
2 things to note: the output tensor comes back transposed relative to the input, and its memory layout differs from that of the input. These two things could be avoided, but that requires at least 2 more all2all operations to transpose and reshape the output array back to the original memory layout of the input array; this could probably be provided as an option to the user.
A minimal example would be:
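(Sketch only: the op name fft3d below is an assumption and should be checked against the files changed in this PR; the surrounding code is the standard Mesh TensorFlow graph/lowering pattern on a single-device placement mesh.)

```python
# Hedged sketch of calling the distributed 3D FFT added in this PR.
# The name mtf.fft3d is assumed; check the PR diff for the exact API.
import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "fft_mesh")

# The three spatial dimensions of the field to transform.
x_dim = mtf.Dimension("nx", 64)
y_dim = mtf.Dimension("ny", 64)
z_dim = mtf.Dimension("nz", 64)

# Import a complex-valued field into the mesh.
field = tf.cast(tf.random.normal([64, 64, 64]), tf.complex64)
mtf_field = mtf.import_tf_tensor(
    mesh, field, shape=mtf.Shape([x_dim, y_dim, z_dim]))

# Hypothetical call to the new spectral op.
kfield = mtf.fft3d(mtf_field)

# Lower on a single-device placement mesh just to run the example;
# on a real mesh the FFT dimensions would be split across processors.
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
result = lowering.export_to_tf_tensor(kfield)
```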