From fb67011900be823a5bc25ce18d0c4a3c253f2a61 Mon Sep 17 00:00:00 2001
From: Hui Kang
Date: Thu, 20 Apr 2023 22:29:26 -0700
Subject: [PATCH] add SOK optimizer redefinition switch

---
 .../experiment/__init__.py                    |  20 ++-
 .../run_function_test_multi_process.sh        |   4 +
 ...stributed_dynamic_test_without_redefine.py | 158 ++++++++++++++++++
 3 files changed, 174 insertions(+), 8 deletions(-)
 create mode 100644 sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/tf2/lookup/lookup_sparse_distributed_dynamic_test_without_redefine.py

diff --git a/sparse_operation_kit/sparse_operation_kit/experiment/__init__.py b/sparse_operation_kit/sparse_operation_kit/experiment/__init__.py
index ea090afff3..c395010767 100644
--- a/sparse_operation_kit/sparse_operation_kit/experiment/__init__.py
+++ b/sparse_operation_kit/sparse_operation_kit/experiment/__init__.py
@@ -66,14 +66,8 @@
 # a specific code path for dl framework tf2.11.0
 import tensorflow
 
-try:
-    if tensorflow.keras.optimizers.legacy.Optimizer.__name__ == "OptimizerV2":
-        tensorflow.keras.optimizers = tensorflow.keras.optimizers.legacy
-except:
-    pass
-
 
-def init(comm_tool="horovod"):
+def init(comm_tool="horovod", use_legacy_optimizer=True):
     """
     Abbreviated as ``sok.experiment.init``.
 
@@ -108,11 +102,21 @@ def init(comm_tool="horovod"):
     ----------
     comm_tool: string
         a string to specify which communication tool to use. Default value is "horovod".
-
+    use_legacy_optimizer: bool
+        Since TensorFlow 2.11.0, the default Keras optimizer is the new experimental optimizer, which SOK will not support in the future. If use_legacy_optimizer is True,
+        SOK redefines tensorflow.keras.optimizers as tensorflow.keras.optimizers.legacy (tf.keras.optimizers.optimizer_v2).
+        Default value is True. If you want to use the new optimizer elsewhere in your code and only use the legacy optimizer inside SOK, set this to False.
     Returns
     -------
     None
     """
+    if use_legacy_optimizer:
+        try:
+            if tensorflow.keras.optimizers.legacy.Optimizer.__name__ == "OptimizerV2":
+                tensorflow.keras.optimizers = tensorflow.keras.optimizers.legacy
+        except:
+            pass
+
     set_comm_tool(comm_tool)
     status = raw_ops.set_default_allocator()
     print("[SOK INFO] Initialize finished, communication tool: " + comm_tool)
diff --git a/sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/run_function_test_multi_process.sh b/sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/run_function_test_multi_process.sh
index 0c246a2aa2..65bf9878a2 100755
--- a/sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/run_function_test_multi_process.sh
+++ b/sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/run_function_test_multi_process.sh
@@ -19,3 +19,7 @@ horovodrun -np ${task_num} python lookup_sparse_distributed_test.py
 horovodrun -np ${task_num} python lookup_sparse_distributed_dynamic_test.py
 horovodrun -np ${task_num} python lookup_sparse_localized_test.py
 horovodrun -np ${task_num} python lookup_sparse_localized_dynamic_test.py
+
+if [[ ${tf_version} -eq 2 ]];then
+horovodrun -np ${task_num} python lookup_sparse_distributed_dynamic_test_without_redefine.py
+fi
diff --git a/sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/tf2/lookup/lookup_sparse_distributed_dynamic_test_without_redefine.py b/sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/tf2/lookup/lookup_sparse_distributed_dynamic_test_without_redefine.py
new file mode 100644
index 0000000000..c30a51a091
--- /dev/null
+++ b/sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/tf2/lookup/lookup_sparse_distributed_dynamic_test_without_redefine.py
@@ -0,0 +1,158 @@
+"""
+ Copyright (c) 2022, NVIDIA CORPORATION.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import time
+import numpy as np
+import tensorflow as tf
+import horovod.tensorflow as hvd
+
+from sparse_operation_kit import experiment as sok
+
+
+if __name__ == "__main__":
+    hvd.init()
+    gpus = tf.config.experimental.list_physical_devices("GPU")
+    for gpu in gpus:
+        tf.config.experimental.set_memory_growth(gpu, True)
+    if gpus:
+        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
+    sok.init(use_legacy_optimizer=False)
+
+    rows = [8192 * 10, 8192]
+    cols = [128, 4]
+    hotness = [10, 3]
+    combiners = ["sum", "sum"]
+    batch_size = 8192
+    iters = 100
+    initial_vals = [13, 17]
+
+    # sok variables
+    sok_vars = [
+        sok.DynamicVariable(dimension=cols[i], initializer=str(initial_vals[i]))
+        for i in range(len(cols))
+    ]
+    local_indices = []
+    for row in rows:
+        local_size = row // hvd.size()
+        if hvd.rank() < row % hvd.size():
+            local_size += 1
+        indices = np.arange(local_size) * hvd.size() + hvd.rank()
+        indices = tf.convert_to_tensor(indices, dtype=tf.int64)
+        local_indices.append(indices)
+
+    # indices
+    total_indices = []
+    for i in range(len(rows)):
+        offsets = np.random.randint(1, hotness[i] + 1, iters * batch_size)
+        offsets = tf.convert_to_tensor(offsets, dtype=tf.int64)
+        offsets = hvd.broadcast(offsets, root_rank=0)
+        values = np.random.randint(0, rows[i], tf.reduce_sum(offsets))
+        values = tf.convert_to_tensor(values, dtype=tf.int64)
+        values = hvd.broadcast(values, root_rank=0)
+        total_indices.append(tf.RaggedTensor.from_row_lengths(values, offsets))
+    left = batch_size // hvd.size() * hvd.rank()
+    right = batch_size // hvd.size() * (hvd.rank() + 1)
+
+    # initialize optimizer
+    if tf.keras.optimizers.legacy.Optimizer.__name__ == "OptimizerV2":
+        optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=1.0)
+    else:
+        optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
+
+    sok_optimizer = sok.OptimizerWrapper(optimizer)
+    tf_vars = [
+        tf.Variable(tf.constant(initial_vals[i], shape=[rows[i], cols[i]], dtype=tf.float32))
+        for i in range(len(rows))
+    ]
+
+    @tf.function
+    def step(params, indices):
+        with tf.GradientTape() as tape:
+            embeddings = sok.lookup_sparse(params, indices, combiners=combiners)
+            loss = 0
+            for i in range(len(embeddings)):
+                loss = loss + tf.reduce_sum(embeddings[i])
+        grads = tape.gradient(loss, params)
+        sok_optimizer.apply_gradients(zip(grads, params))
+        loss = hvd.allreduce(loss, op=hvd.Sum)
+        return loss
+
+    # Do training
+    loss1 = []
+    ts = []
+    t = time.time()
+    for i in range(iters):
+        ts.append(time.time() - t)
+        t = time.time()
+        indices = []
+        for j in range(len(total_indices)):
+            indices.append(total_indices[j][i * batch_size + left : i * batch_size + right])
+        loss = step(sok_vars, indices)
+        loss1.append(loss)
+        print("-" * 30 + "iteration %d" % i + "-" * 30)
+        print("loss:", loss)
+    out1 = []
+    for i in range(len(sok_vars)):
+        out1.append(tf.nn.embedding_lookup(sok_vars[i], local_indices[i]))
+
+    @tf.function
+    def step2(params, indices):
+        with tf.GradientTape() as tape:
+            loss = 0
+            for i in range(len(params)):
+                embedding = tf.nn.embedding_lookup_sparse(
+                    params[i], indices[i], None, combiner=combiners[i]
+                )
+                loss = loss + tf.reduce_sum(embedding)
+        grads = tape.gradient(loss, params)
+        grads = [hvd.allreduce(grad, op=hvd.Sum) for grad in grads]
+        optimizer.apply_gradients(zip(grads, params))
+        loss = hvd.allreduce(loss, op=hvd.Sum)
+        return loss
+
+    loss2 = []
+    for i in range(iters):
+        indices = []
+        for j in range(len(total_indices)):
+            indices.append(
+                total_indices[j][i * batch_size + left : i * batch_size + right].to_sparse()
+            )
+        loss = step2(tf_vars, indices)
+        loss2.append(loss)
+        print("-" * 30 + "iteration %d" % i + "-" * 30)
+        print("tf loss:", loss)
+    out2 = []
+    for i, v in enumerate(tf_vars):
+        out2.append(tf.nn.embedding_lookup(v, local_indices[i]))
+
+    # Check results
+    diff = 0
+    for i in range(len(out1)):
+        length = out1[i] ** 2 + out2[i] ** 2 + 1e-8
+        diff = diff + tf.reduce_sum((out1[i] - out2[i]) ** 2 / length)
+    print("[SOK INFO] diff:", diff)
+    assert diff < 1e-6
+
+    diff = 0
+    for i in range(iters):
+        length = loss1[i] ** 2 + loss2[i] ** 2 + 1e-8
+        diff = diff + (loss1[i] - loss2[i]) ** 2 / length
+    print("[SOK INFO] loss diff:", diff)
+    assert diff < 1e-6
+
+    print("[SOK INFO] lookup_sparse distributed with dynamic variable test passed")
+    ts = ts[5:]
+    print("[SOK INFO] Average time: %f ms/iteration" % (sum(ts) / len(ts) * 1000))
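
Usage sketch for the new switch, distilled from the test added above: only sok.init, sok.OptimizerWrapper, and sok.DynamicVariable come from this patch; the embedding dimension, initializer, and learning rate below are illustrative values, not defaults.

    import tensorflow as tf
    import horovod.tensorflow as hvd
    from sparse_operation_kit import experiment as sok

    hvd.init()
    # Leave tf.keras.optimizers untouched so the rest of the program can keep
    # using the new (experimental) Keras optimizers introduced in TF 2.11.
    sok.init(use_legacy_optimizer=False)

    # SOK still expects a legacy (optimizer_v2) optimizer, so select it explicitly
    # before wrapping it for SOK variables.
    if tf.keras.optimizers.legacy.Optimizer.__name__ == "OptimizerV2":
        optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=1.0)
    else:
        optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
    sok_optimizer = sok.OptimizerWrapper(optimizer)

    # Example dynamic embedding table; dimension and initializer are illustrative only.
    emb_var = sok.DynamicVariable(dimension=16, initializer="13")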