From 192d3df9345bd838f5c4f163971b850220e76e91 Mon Sep 17 00:00:00 2001
From: Meenakshi Venkataraman
Date: Tue, 26 Apr 2022 11:37:07 -0700
Subject: [PATCH] TFTRT: Respect device placement requested by user

Previously, TRT engines would always default to running on GPU 0. On a
multi-GPU system this does not make the best use of the available
resources. This commit adds the ability to specify the GPU on which the
TRT engine should run.
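
With this change in place, a caller can request a specific GPU by
wrapping the conversion in a device scope. A minimal usage sketch
(illustrative only, not part of this patch; `saved_model_dir` is a
placeholder path, and at least two GPUs are assumed to be visible):

    import tensorflow as tf
    from tensorflow.python.compiler.tensorrt import trt_convert as trt

    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=saved_model_dir)

    # convert() reads the requested device off an empty tensor created
    # inside the call, so the device scope must be active around it.
    with tf.device("GPU:1"):
        converter.convert()

Requesting a non-GPU device raises a ValueError, and requesting GPU:0
leaves the default placement unchanged.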

Signed-off-by: Meenakshi Venkataraman
---
 .../python/compiler/tensorrt/trt_convert.py   | 29 ++++++-
 .../compiler/tensorrt/trt_convert_test.py     | 75 +++++++++++++++++++
 2 files changed, 102 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py
index ac1c18a8c5c552..6c606f2b5f2cf4 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -1131,6 +1131,7 @@ def __init__(self,
     self._calibration_input_fn = None
 
     self._converted = False
+    self._device = None
     self._build_called_once = False
     self._calibrated = False
 
@@ -1240,6 +1241,17 @@ def convert(self, calibration_input_fn=None):
     """
     assert not self._converted
 
+    # Creating an empty tensor to fetch queried device
+    device_requested = array_ops.zeros([]).device
+
+    if "gpu" not in device_requested.lower():
+      raise ValueError(f"Specified device is not a GPU: {device_requested}")
+
+    if "gpu:0" not in device_requested.lower():
+      self._device = device_requested
+      logging.info(f"Placing imported graph from "
+                   f"`{self._input_saved_model_dir}` on device: {self._device}")
+
     if (self._need_calibration and not calibration_input_fn):
       raise ValueError("Should specify calibration_input_fn because INT8 "
                        "calibration is needed")
@@ -1251,8 +1263,21 @@ def convert(self, calibration_input_fn=None):
                                   self._input_saved_model_tags)
     func = self._saved_model.signatures[self._input_saved_model_signature_key]
     frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
-    grappler_meta_graph_def = saver.export_meta_graph(
-        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
+    frozen_graph_def = frozen_func.graph.as_graph_def()
+
+    # Clear any prior device assignments
+    logging.info("Clearing prior device assignments in loaded saved model")
+    for node in frozen_graph_def.node:
+      node.device = ""
+
+    if self._device is None:
+      grappler_meta_graph_def = saver.export_meta_graph(
+          graph_def=frozen_graph_def, graph=frozen_func.graph)
+    else:
+      with ops.Graph().as_default() as graph, ops.device(self._device):
+        importer.import_graph_def(frozen_graph_def, name="")
+        grappler_meta_graph_def = saver.export_meta_graph(
+            graph_def=graph.as_graph_def(), graph=graph)
 
     # Add a collection 'train_op' so that Grappler knows the outputs.
     fetch_collection = meta_graph_pb2.CollectionDef()
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
index e36d5acd31cb8a..3008d7a5f7b169 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -26,6 +26,7 @@
 from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
 from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance  # pylint: disable=g-importing-member
 from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import config
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.compiler.tensorrt import trt_convert
 from tensorflow.python.eager import def_function
@@ -1066,6 +1067,80 @@ def testTrtGraphConverterV2_SaveWithOptions(self):
     mock_save.save.assert_called_once_with(
         mock.ANY, mock.ANY, mock.ANY, options=options)
 
+  @parameterized.named_parameters([
+      ("NoDeviceAssignment", None),
+      ("GPU1", "GPU:1"),
+  ])
+  @test_util.run_v2_only
+  def testTrtGraphConverter_DevicePlacement(self, device_id):
+    """Test case for trt_convert.TrtGraphConverter()."""
+
+    gpus = config.list_physical_devices('GPU')
+    if len(gpus) < 2:
+      self.skipTest('Expected at least 2 GPUs but found {} GPUs'.format(
+          len(gpus)))
+
+    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+
+    # Create a model and save it.
+    input_saved_model_dir = self.mkdtemp()
+    root = self._GetModelForV2()
+    save.save(root, input_saved_model_dir,
+              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
+
+    converter = self._CreateConverterV2(
+        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
+
+    converted_model = None
+    # Specify device on which converted model should be placed
+    with ops.device(device_id):
+      converted_model = converter.convert()
+
+    # Verify that TRT engine op has the correct device.
+    self._CheckTrtOps(converter._converted_func)
+
+    actual_device_id = self._GetUniqueTRTEngineOp(
+        converter._converted_graph_def).device
+
+    expected_device_id = None
+    if device_id is not None:
+      expected_device_id = device_id
+    else:
+      expected_device_id = 'GPU:0'
+
+    self.assertTrue(expected_device_id.lower() in actual_device_id.lower())
+
+    del converter
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
+  @test_util.run_v2_only
+  def testTrtGraphConverter_DevicePlacementOnCPU(self):
+    """Test case for trt_convert.TrtGraphConverter()."""
+
+    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+
+    # Create a model and save it.
+    input_saved_model_dir = self.mkdtemp()
+    root = self._GetModelForV2()
+    save.save(root, input_saved_model_dir,
+              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
+
+    # Run TRT conversion.
+    converter = self._CreateConverterV2(
+        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
+
+    converted_model = None
+    # Specify device on which converted model should be placed
+    with self.assertRaisesRegex(
+        ValueError,
+        r"Specified device is not a GPU"):
+      with ops.device('CPU'):
+        converted_model = converter.convert()
+
+    del converter
+    gc.collect()  # Force GC to destroy the TRT engine cache.
 
 if __name__ == "__main__" and is_tensorrt_enabled():
   test.main()