TFTRT: Respect device placement requested by user
Previously, TRT engines would always default to
running on GPU 0. On a multi-GPU system, this
does not make the best use of the available
resources. This commit adds the ability to specify
the GPU on which the TRT engine should run.

Signed-off-by: Meenakshi Venkataraman <[email protected]>
meena-at-work committed May 10, 2022
1 parent aefc36f commit 192d3df
Showing 2 changed files with 102 additions and 2 deletions.
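A minimal usage sketch of the behavior this change enables, assuming the TF-TRT converter API trt_convert.TrtGraphConverterV2 and a hypothetical SavedModel path. With this commit, the device active in the tf.device() scope enclosing convert() determines where the imported graph, and hence the TRT engine, is placed, instead of always GPU 0:

import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Hypothetical SavedModel location; any TF2 SavedModel works here.
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/saved_model",
    precision_mode=trt.TrtPrecisionMode.FP32)

# convert() reads the device requested by the enclosing scope; a GPU
# other than GPU:0 is now honored rather than ignored.
with tf.device("GPU:1"):
  converter.convert()

converter.save("/tmp/saved_model_trt")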
29 changes: 27 additions & 2 deletions tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -1131,6 +1131,7 @@ def __init__(self,
     self._calibration_input_fn = None
 
     self._converted = False
+    self._device = None
     self._build_called_once = False
     self._calibrated = False
 
@@ -1240,6 +1241,17 @@ def convert(self, calibration_input_fn=None):
     """
     assert not self._converted
 
+    # Create an empty tensor; its placement reveals the requested device.
+    device_requested = array_ops.zeros([]).device
+
+    if "gpu" not in device_requested.lower():
+      raise ValueError(f"Specified device is not a GPU: {device_requested}")
+
+    if "gpu:0" not in device_requested.lower():
+      self._device = device_requested
+      logging.info(f"Placing imported graph from "
+                   f"`{self._input_saved_model_dir}` on device: {self._device}")
+
     if (self._need_calibration and not calibration_input_fn):
       raise ValueError("Should specify calibration_input_fn because INT8 "
                        "calibration is needed")
@@ -1251,8 +1263,21 @@ def convert(self, calibration_input_fn=None):
                                   self._input_saved_model_tags)
     func = self._saved_model.signatures[self._input_saved_model_signature_key]
     frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
-    grappler_meta_graph_def = saver.export_meta_graph(
-        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
+    frozen_graph_def = frozen_func.graph.as_graph_def()
+
+    # Clear any prior device assignments
+    logging.info("Clearing prior device assignments in loaded saved model")
+    for node in frozen_graph_def.node:
+      node.device = ""
+
+    if self._device is None:
+      grappler_meta_graph_def = saver.export_meta_graph(
+          graph_def=frozen_graph_def, graph=frozen_func.graph)
+    else:
+      with ops.Graph().as_default() as graph, ops.device(self._device):
+        importer.import_graph_def(frozen_graph_def, name="")
+        grappler_meta_graph_def = saver.export_meta_graph(
+            graph_def=graph.as_graph_def(), graph=graph)
 
     # Add a collection 'train_op' so that Grappler knows the outputs.
     fetch_collection = meta_graph_pb2.CollectionDef()
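The conversion path above hinges on a small trick: the device attribute of an eagerly created tensor reflects the enclosing tf.device() scope, which is how convert() learns the requested placement before applying the case-insensitive "gpu" and "gpu:0" substring checks. A standalone sketch, assuming a machine with a second GPU:

import tensorflow as tf

with tf.device("GPU:1"):
  t = tf.zeros([])

# Prints the fully qualified device name, e.g.
# /job:localhost/replica:0/task:0/device:GPU:1
print(t.device)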
75 changes: 75 additions & 0 deletions tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -26,6 +26,7 @@
 from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
 from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance  # pylint: disable=g-importing-member
 from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import config
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.compiler.tensorrt import trt_convert
 from tensorflow.python.eager import def_function
@@ -1066,6 +1067,80 @@ def testTrtGraphConverterV2_SaveWithOptions(self):
     mock_save.save.assert_called_once_with(
         mock.ANY, mock.ANY, mock.ANY, options=options)
 
+  @parameterized.named_parameters([
+      ("NoDeviceAssignment", None),
+      ("GPU1", "GPU:1"),
+  ])
+  @test_util.run_v2_only
+  def testTrtGraphConverter_DevicePlacement(self, device_id):
+    """Test case for trt_convert.TrtGraphConverter()."""
+
+    gpus = config.list_physical_devices("GPU")
+    if len(gpus) < 2:
+      self.skipTest("Expected at least 2 GPUs but found {} GPUs".format(
+          len(gpus)))
+
+    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+
+    # Create a model and save it.
+    input_saved_model_dir = self.mkdtemp()
+    root = self._GetModelForV2()
+    save.save(root, input_saved_model_dir,
+              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
+
+    converter = self._CreateConverterV2(
+        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
+
+    converted_model = None
+    # Specify the device on which the converted model should be placed.
+    with ops.device(device_id):
+      converted_model = converter.convert()
+
+    # Verify that the TRT engine op has the correct device.
+    self._CheckTrtOps(converter._converted_func)
+
+    actual_device_id = self._GetUniqueTRTEngineOp(
+        converter._converted_graph_def).device
+
+    expected_device_id = None
+    if device_id is not None:
+      expected_device_id = device_id
+    else:
+      expected_device_id = "GPU:0"
+
+    self.assertIn(expected_device_id.lower(), actual_device_id.lower())
+
+    del converter
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
+  @test_util.run_v2_only
+  def testTrtGraphConverter_DevicePlacementOnCPU(self):
+    """Test case for trt_convert.TrtGraphConverter()."""
+
+    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+
+    # Create a model and save it.
+    input_saved_model_dir = self.mkdtemp()
+    root = self._GetModelForV2()
+    save.save(root, input_saved_model_dir,
+              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
+
+    # Run TRT conversion.
+    converter = self._CreateConverterV2(
+        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
+
+    converted_model = None
+    # Conversion must fail when the requested device is not a GPU.
+    with self.assertRaisesRegex(
+        ValueError,
+        r"Specified device is not a GPU"):
+      with ops.device("CPU"):
+        converted_model = converter.convert()
+
+    del converter
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
 if __name__ == "__main__" and is_tensorrt_enabled():
   test.main()
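The second half of the change, re-importing the frozen graph under an explicit device scope, can also be reproduced in isolation. A minimal sketch using a toy GraphDef in place of the frozen model, showing that nodes imported inside tf.device() are stamped with that device (the same effect convert() gets from importer.import_graph_def under ops.device):

import tensorflow as tf

# Toy graph standing in for the frozen SavedModel graph.
g1 = tf.Graph()
with g1.as_default():
  tf.constant(1.0, name="c")
graph_def = g1.as_graph_def()

# Clear prior placements, mirroring what convert() does.
for node in graph_def.node:
  node.device = ""

# Re-import under the requested device scope.
g2 = tf.Graph()
with g2.as_default(), tf.device("GPU:1"):
  tf.graph_util.import_graph_def(graph_def, name="")

for node in g2.as_graph_def().node:
  print(node.name, node.device)  # e.g. c /device:GPU:1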
