Skip to content
This repository has been archived by the owner on Jan 21, 2025. It is now read-only.

Adds spectral functions to Mesh TensorFlow #250

Open
wants to merge 25 commits into
base: master
Choose a base branch
from

Conversation

EiffL
Copy link

@EiffL EiffL commented Nov 26, 2020

This PR adds spectral operations needed for the flowpm project in Mesh TensorFlow, which was the subject of this blogpost: https://blog.tensorflow.org/2020/03/simulating-universe-in-tensorflow.html .

These operations are useful for lots of applications including N-body simulations and MRI reconstructions. For now, we have only added the implementation of 3D FFTs.

The implementation is based on applying a series of 1D FFTs along the trailing dimensions of the input tensors, then using all2all communications and local transpose operations, to transpose the tensor until all 3 dimensions have been transformed.

The algorithm is illustrated here:
image
from https://www-user.tu-chemnitz.de/~potts/workgroup/pippig/paper/PFFT_SIAM_88588.pdf

2 things to note:

  • The user needs to specify the Fourier dimensions, which will be referenced in the mesh layout, to make sure the output of the FFT remains distributed.
  • The output of the FFT is transposed, to save on extra all2all operations.
    These two things could be avoided but require at least 2 more all2all operations to transpose and reshape the output array back to the original memory layout of the input array... this could be provided as a option to the user probably.

A minimal example for would be:

batch_dim = mtf.Dimension("batch", batch_size)
x_dim = mtf.Dimension("nx", nc)
y_dim = mtf.Dimension("ny", nc)
z_dim = mtf.Dimension("nz", nc)

kx_dim = mtf.Dimension("kx", nc)
ky_dim = mtf.Dimension("ky", nc)
kz_dim = mtf.Dimension("kz", nc)

# Create field
field = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim])
# Apply FFT
fft_field = mtf.signal.fft3d(mtf.cast(field, tf.complex64), [kx_dim, ky_dim, kz_dim])
# fft_field as shape: [batch, ky, kz, kx]
# Inverse FFT
recfield = mtf.cast(mtf.signal.ifft3d(fft_field, [x_dim, y_dim, z_dim]), tf.float32)

# The  following layout would be appropriate
mesh_shape = [("row", nblockx), ("col", nblocky)]
layout_rules = [("nx", "row"), ("ny", "col"),
      ("ky", "row"), ("kz", "col"),]   # Note  that the Fourier dimensions are aslo split differently

@google-cla
Copy link

google-cla bot commented Nov 26, 2020

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.


What to do if you already signed the CLA

Individual signers
Corporate signers

ℹ️ Googlers: Go here for more info.

@google-cla google-cla bot added the cla: no label Nov 26, 2020
@EiffL
Copy link
Author

EiffL commented Nov 26, 2020

@googlebot I signed it!

@google-cla
Copy link

google-cla bot commented Nov 26, 2020

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

ℹ️ Googlers: Go here for more info.

@zaccharieramzi
Copy link

@googlebot I consent.

@google-cla google-cla bot added cla: yes and removed cla: no labels Nov 26, 2020
@zaccharieramzi
Copy link

An important note is that this PR in addition to the spectral ops adds complex support for the gradient (otherwise the gradient will not flow through the spectral ops) and complex manipulation operations if you need to navigate between complex-valued tensors and float-valued tensors.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants