This repository has been archived by the owner on Jan 21, 2025. It is now read-only.

Adds spectral functions to Mesh TensorFlow #250

Open · wants to merge 25 commits into base: master
Commits (25)
83d0bbf  Adds spectral functions (Jan 30, 2020)
f9b0dab  Merge branch 'master' into spectral (EiffL, Aug 13, 2020)
86d9991  added complex manipulation ops (zaccharieramzi, Nov 25, 2020)
87510bc  added complex manipulation ops tests (zaccharieramzi, Nov 25, 2020)
159027a  cleaned signal ops imports (zaccharieramzi, Nov 25, 2020)
6f85ab6  corrected complex ops test naming (zaccharieramzi, Nov 25, 2020)
0e2760d  added tests for signal ops (zaccharieramzi, Nov 25, 2020)
14dda78  corrected volume shape in signal ops tests (zaccharieramzi, Nov 25, 2020)
69b2d76  corrected shaping for signal ops tests (zaccharieramzi, Nov 25, 2020)
7151756  slightly complexified the signal ops test (zaccharieramzi, Nov 25, 2020)
2074197  refactored fft3d direct (zaccharieramzi, Nov 25, 2020)
afe7c63  refactored ifft3d (zaccharieramzi, Nov 25, 2020)
9ef47fc  corrected documentation for shape return of fft3d (zaccharieramzi, Nov 25, 2020)
782250e  changed transpose name (zaccharieramzi, Nov 25, 2020)
7ea6a90  simplified the implementation of the split complex op (zaccharieramzi, Nov 26, 2020)
bcd791e  corrected slicewise input list expected (zaccharieramzi, Nov 26, 2020)
ea9b55f  corrected name for the base op (zaccharieramzi, Nov 26, 2020)
ce9e850  corrected doc for fft3d function (zaccharieramzi, Nov 26, 2020)
f6a2dfe  corrected legacy typo in fft base class (zaccharieramzi, Nov 26, 2020)
497c257  corrected order of transpose in fft test to compare the output directly (zaccharieramzi, Nov 26, 2020)
28debd5  Proposes some further code compression to remove duplication of trans… (EiffL, Nov 26, 2020)
daab646  fix bug introduced by previous commit (EiffL, Nov 26, 2020)
3bcaef8  Merge pull request #2 from zaccharieramzi/spectral-tests (EiffL, Nov 26, 2020)
b6e9d08  Merge branch 'master' of https://github.com/tensorflow/mesh into tens… (EiffL, Nov 26, 2020)
b78533f  Merge branch 'tensorflow-master' into spectral (EiffL, Nov 26, 2020)
149 changes: 149 additions & 0 deletions examples/fft_benchmark.py
@@ -0,0 +1,149 @@
"""
Benchmark script for studying the scaling of distributed FFTs on Mesh Tensorflow
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

from tensorflow.python.tpu import tpu_config # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.tpu import tpu_estimator # pylint: disable=g-direct-tensorflow-import
from tensorflow_estimator.python.estimator import estimator as estimator_lib

# Cloud TPU Cluster Resolver flags
tf.flags.DEFINE_string(
"tpu", default=None,
help="The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
tf.flags.DEFINE_string(
"tpu_zone", default=None,
help="[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
tf.flags.DEFINE_string(
"gcp_project", default=None,
help="[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")

tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir")

tf.flags.DEFINE_integer("cube_size", 512, "Size of the 3D volume.")
tf.flags.DEFINE_integer("batch_size", 128,
"Mini-batch size for the training. Note that this "
"is the global batch size and not the per-shard batch.")

tf.flags.DEFINE_string("mesh_shape", "b1:32", "mesh shape")
tf.flags.DEFINE_string("layout", "nx:b1,tny:b1", "layout rules")

FLAGS = tf.flags.FLAGS

def benchmark_model(mesh):
"""
Initializes a 3D volume with random noise, and execute a forward FFT
"""
batch_dim = mtf.Dimension("batch", FLAGS.batch_size)

# Declares real space dimensions
x_dim = mtf.Dimension("nx", FLAGS.cube_size)
y_dim = mtf.Dimension("ny", FLAGS.cube_size)
z_dim = mtf.Dimension("nz", FLAGS.cube_size)

# Declares Fourier space dimensions
tx_dim = mtf.Dimension("tnx", FLAGS.cube_size)
ty_dim = mtf.Dimension("tny", FLAGS.cube_size)
tz_dim = mtf.Dimension("tnz", FLAGS.cube_size)

# Create field
field = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim])

# Apply FFT
fft_field = mtf.signal.fft3d(mtf.cast(field, tf.complex64), [tx_dim, ty_dim, tz_dim])

# Inverse FFT
rfield = mtf.cast(mtf.signal.ifft3d(fft_field, [x_dim, y_dim, z_dim]), tf.float32)

# Compute errors
err = mtf.reduce_max(mtf.abs(field - rfield))
return err

def model_fn(features, labels, mode, params):
"""A model is called by TpuEstimator."""
del labels
del features

mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

ctx = params['context']
num_hosts = ctx.num_hosts
host_placement_fn = ctx.tpu_host_placement_function
device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
tf.logging.info('device_list = %s' % device_list,)

mesh_devices = [''] * mesh_shape.size
mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "fft_mesh")

with mtf.utils.outside_all_rewrites():
err = benchmark_model(mesh)

lowering = mtf.Lowering(graph, {mesh: mesh_impl})

tf_err = tf.to_float(lowering.export_to_tf_tensor(err))

with mtf.utils.outside_all_rewrites():
return tpu_estimator.TPUEstimatorSpec(mode, loss=tf_err)


def main(_):

tf.logging.set_verbosity(tf.logging.INFO)
mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

# Resolve the TPU environment
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
FLAGS.tpu,
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project
)

run_config = tf.estimator.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=FLAGS.model_dir,
save_checkpoints_steps=None, # Disable the default saver
save_checkpoints_secs=None, # Disable the default saver
log_step_count_steps=100,
save_summary_steps=100,
tpu_config=tpu_config.TPUConfig(
num_shards=mesh_shape.size,
iterations_per_loop=100,
num_cores_per_replica=1,
per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

model = tpu_estimator.TPUEstimator(
use_tpu=True,
model_fn=model_fn,
config=run_config,
train_batch_size=FLAGS.batch_size,
eval_batch_size=FLAGS.batch_size)

def dummy_input_fn(params):
"""Dummy input function """
return tf.zeros(shape=[params['batch_size']], dtype=tf.float32), tf.zeros(shape=[params['batch_size']], dtype=tf.float32)

# Run evaluate loop for ever, we will be connecting to this process using a profiler
model.evaluate(input_fn=dummy_input_fn, steps=100000)

if __name__ == "__main__":
tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
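The script above only targets TPU pods. For quick local verification, here is a minimal sketch (not part of this PR, assuming the PR's mtf.signal module is importable) that runs the same round-trip check on a single-device placement mesh, calling the ops exactly as the benchmark does:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "fft_mesh")

# Small cube so everything fits on one device.
batch_dim = mtf.Dimension("batch", 1)
x_dim, y_dim, z_dim = [mtf.Dimension(n, 8) for n in ("nx", "ny", "nz")]
tx_dim, ty_dim, tz_dim = [mtf.Dimension(n, 8) for n in ("tnx", "tny", "tnz")]

field = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim])
fft_field = mtf.signal.fft3d(mtf.cast(field, tf.complex64),
                             [tx_dim, ty_dim, tz_dim])
rfield = mtf.cast(mtf.signal.ifft3d(fft_field, [x_dim, y_dim, z_dim]),
                  tf.float32)
err = mtf.reduce_max(mtf.abs(field - rfield))

# Lower onto a single-device placement mesh and evaluate.
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
with tf.Session() as sess:
  print("max reconstruction error:",
        sess.run(lowering.export_to_tf_tensor(err)))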
59 changes: 56 additions & 3 deletions mesh_tensorflow/ops.py
@@ -1594,8 +1594,8 @@ def to_string(self):
  @property
  def has_gradient(self):
    return (
-        [t for t in self.inputs if t.dtype.is_floating] and
-        [t for t in self.outputs if t.dtype.is_floating])
+        [t for t in self.inputs if t.dtype.is_floating or t.dtype.is_complex] and
+        [t for t in self.outputs if t.dtype.is_floating or t.dtype.is_complex])

  def gradient(self, unused_grad_ys):
    raise NotImplementedError("Gradient not implemented")
@@ -5815,7 +5815,7 @@ def random_uniform(mesh, shape, **kwargs):


def random_normal(mesh, shape, **kwargs):
-  """Random uniform.
+  """Random normal.

  Args:
    mesh: a Mesh
@@ -6674,3 +6674,56 @@ def reduce_first(tensor, reduced_dim):
  r = mtf_range(tensor.mesh, reduced_dim, dtype=tf.int32)
  first_element_filter = cast(equal(r, 0), tensor.dtype)
  return reduce_sum(tensor * first_element_filter, reduced_dim=reduced_dim)


def to_complex(x, complex_dim=None):
  """Gathers the real and imaginary parts of a tensor into a complex tensor.

  Args:
    x: a float Tensor, whose first half along complex_dim holds the real
      part and whose second half holds the imaginary part.
    complex_dim: a Dimension along which the real and imaginary parts are
      concatenated. Defaults to None, which corresponds to the last
      dimension of the tensor.
  Returns:
    a Tensor, complex-valued, whose complex_dim is half the original size.
  """
  if complex_dim is None:
    complex_dim = x.shape[-1]
  x_real, x_imag = split(x, complex_dim, 2)
  x_real = cast(x_real, tf.complex64)
  x_imag = cast(x_imag, tf.complex64)
  x_complex = x_real + 1j * x_imag
  return x_complex


def split_complex(x, complex_dim=None):
  """Splits a complex tensor into its real and imaginary parts, concatenated.

  Args:
    x: a complex Tensor
    complex_dim: a Dimension along which you want the split to happen.
      Defaults to None, which corresponds to the last dimension of the
      tensor.
  Returns:
    a Tensor, float-valued, whose complex_dim is twice the original size.
  """
  if complex_dim is None:
    split_dim = x.shape.dims[-1]
    split_axis = -1
  else:
    split_dim = complex_dim
    split_axis = x.shape.index(complex_dim)
  splittable_dims = [d for d in x.shape if d != split_dim]
  def tf_fn(tf_input):
    tf_real = tf.math.real(tf_input)
    tf_imag = tf.math.imag(tf_input)
    output = tf.concat([tf_real, tf_imag], axis=split_axis)
    return output
  output = slicewise(
      tf_fn,
      [x],
      output_shape=x.shape.resize_dimension(split_dim.name, split_dim.size * 2),
      output_dtype=tf.float32,
      splittable_dims=splittable_dims,
      name='split_complex',
  )
  return output
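To illustrate how the two ops compose, here is a small usage sketch (not part of the diff; it uses the same single-device placement-mesh setup as the unit tests below). The shapes mirror the tests: split_complex doubles the chosen dimension, to_complex halves it back, so the two ops are mutual inverses.

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
shape = [mtf.Dimension("batch", 1), mtf.Dimension("n", 10),
         mtf.Dimension("ri", 4)]
x = mtf.import_tf_tensor(mesh, tf.random.normal([1, 10, 4]), shape=shape)

x_c = mtf.to_complex(x)        # complex64, last dim size 2
x_f = mtf.split_complex(x_c)   # float32, back to last dim size 4

mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
# [real | imag] -> complex -> [real | imag]: recovers the original input.
x_f_tf = lowering.export_to_tf_tensor(x_f)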
45 changes: 45 additions & 0 deletions mesh_tensorflow/ops_test.py
@@ -671,6 +671,51 @@ def x_squared_plus_x(x):
        self.evaluate(expected_dx))


class ComplexManipulationTest(tf.test.TestCase):

  def setUp(self):
    super(ComplexManipulationTest, self).setUp()
    self.graph = mtf.Graph()
    self.mesh = mtf.Mesh(self.graph, "my_mesh")

  def testToComplex(self):
    tensor = tf.random.normal([1, 10, 4])
    mtf_shape = [
        mtf.Dimension(f'dim_{i}', s) for i, s in enumerate(tensor.shape)]
    tensor_mesh = mtf.import_tf_tensor(self.mesh, tensor, shape=mtf_shape)
    outputs = mtf.to_complex(tensor_mesh)
    assert outputs.dtype == tf.complex64
    assert len(outputs.shape) == 3
    assert outputs.shape[-1].size == 2
    assert ([s.size for s in outputs.shape[:-1]] ==
            [s.size for s in tensor_mesh.shape[:-1]])
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
    outputs_tf = lowering.export_to_tf_tensor(outputs)
    self.assertAllEqual(
        outputs_tf,
        tf.complex(tensor[..., 0:2], tensor[..., 2:4]),
    )

  def testSplitComplex(self):
    tensor = tf.complex(
        tf.random.normal([1, 10, 2]),
        tf.random.normal([1, 10, 2]),
    )
    mtf_shape = [
        mtf.Dimension(f'dim_{i}', s) for i, s in enumerate(tensor.shape)]
    tensor_mesh = mtf.import_tf_tensor(self.mesh, tensor, shape=mtf_shape)
    outputs = mtf.split_complex(tensor_mesh)
    assert outputs.dtype == tf.float32
    assert len(outputs.shape) == 3
    assert outputs.shape[-1].size == 4
    assert ([s.size for s in outputs.shape[:-1]] ==
            [s.size for s in tensor_mesh.shape[:-1]])
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
    outputs_tf = lowering.export_to_tf_tensor(outputs)
    self.assertAllEqual(
        outputs_tf,
        tf.concat([tf.math.real(tensor), tf.math.imag(tensor)], axis=-1),
    )


if __name__ == "__main__":
  tf.disable_v2_behavior()
  tf.enable_eager_execution()
2 changes: 1 addition & 1 deletion mesh_tensorflow/ops_with_redefined_builtins.py
@@ -24,7 +24,7 @@
from mesh_tensorflow.ops import mtf_pow as pow  # pylint: disable=redefined-builtin,unused-import
from mesh_tensorflow.ops import mtf_range as range  # pylint: disable=redefined-builtin,unused-import
from mesh_tensorflow.ops import mtf_slice as slice  # pylint: disable=redefined-builtin,unused-import

+import mesh_tensorflow.signal_ops as signal


# TODO(trandustin): Seal module.
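With this import in place, the new functionality is reachable from the top-level mtf namespace, which is how both the benchmark and the tests above access it. A quick sanity sketch (assuming this PR's branch is installed; these attributes do not exist on master):

import mesh_tensorflow as mtf

# Spectral ops live under the new mtf.signal namespace...
assert hasattr(mtf.signal, 'fft3d')
assert hasattr(mtf.signal, 'ifft3d')
# ...while the complex-manipulation helpers sit directly on mtf.
assert hasattr(mtf, 'to_complex')
assert hasattr(mtf, 'split_complex')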