Merge pull request tensorflow#56056 from meena-at-work:meenakshiv/tftrt-device-placement

PiperOrigin-RevId: 448974242
tensorflower-gardener committed May 16, 2022
2 parents 18c208a + 192d3df commit 67529f3
Showing 2 changed files with 100 additions and 2 deletions.
29 changes: 27 additions & 2 deletions tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -1131,6 +1131,7 @@ def __init__(self,
     self._calibration_input_fn = None
 
     self._converted = False
+    self._device = None
     self._build_called_once = False
     self._calibrated = False

@@ -1240,6 +1241,17 @@ def convert(self, calibration_input_fn=None):
     """
     assert not self._converted
 
+    # Create an empty tensor to query the device requested for conversion.
+    device_requested = array_ops.zeros([]).device
+
+    if "gpu" not in device_requested.lower():
+      raise ValueError(f"Specified device is not a GPU: {device_requested}")
+
+    if "gpu:0" not in device_requested.lower():
+      self._device = device_requested
+      logging.info(f"Placing imported graph from "
+                   f"`{self._input_saved_model_dir}` on device: {self._device}")
+
     if (self._need_calibration and not calibration_input_fn):
       raise ValueError("Should specify calibration_input_fn because INT8 "
                        "calibration is needed")
@@ -1251,8 +1263,21 @@ def convert(self, calibration_input_fn=None):
         self._input_saved_model_tags)
     func = self._saved_model.signatures[self._input_saved_model_signature_key]
     frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
-    grappler_meta_graph_def = saver.export_meta_graph(
-        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
+    frozen_graph_def = frozen_func.graph.as_graph_def()
+
+    # Clear any prior device assignments
+    logging.info("Clearing prior device assignments in loaded saved model")
+    for node in frozen_graph_def.node:
+      node.device = ""
+
+    if self._device is None:
+      grappler_meta_graph_def = saver.export_meta_graph(
+          graph_def=frozen_graph_def, graph=frozen_func.graph)
+    else:
+      with ops.Graph().as_default() as graph, ops.device(self._device):
+        importer.import_graph_def(frozen_graph_def, name="")
+        grappler_meta_graph_def = saver.export_meta_graph(
+            graph_def=graph.as_graph_def(), graph=graph)
 
     # Add a collection 'train_op' so that Grappler knows the outputs.
     fetch_collection = meta_graph_pb2.CollectionDef()
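The placement step above is self-contained enough to sketch in isolation: wipe the per-node device strings recorded in the SavedModel, then re-import the GraphDef under an explicit device scope so every node inherits the new placement. A hedged sketch of that idea, assuming graph_def is a frozen GraphDef; place_graph_on is a hypothetical helper, not part of the converter API:

    from tensorflow.python.framework import importer, ops

    def place_graph_on(graph_def, device):
      """Hypothetical helper: pin every node of `graph_def` to `device`."""
      for node in graph_def.node:
        node.device = ""  # drop the placement recorded in the SavedModel
      with ops.Graph().as_default() as graph, ops.device(device):
        importer.import_graph_def(graph_def, name="")
      return graph

Exporting the re-imported graph with export_meta_graph, as convert() does, then hands Grappler a graph whose placement matches the caller's request.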
73 changes: 73 additions & 0 deletions tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -26,6 +27,7 @@
 from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
 from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance  # pylint: disable=g-importing-member
 from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import config
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.compiler.tensorrt import trt_convert
 from tensorflow.python.eager import def_function
@@ -1066,6 +1067,78 @@ def testTrtGraphConverterV2_SaveWithOptions(self):
     mock_save.save.assert_called_once_with(
         mock.ANY, mock.ANY, mock.ANY, options=options)
 
+  @parameterized.named_parameters([
+      ("NoDeviceAssignment", None),
+      ("GPU1", "GPU:1"),
+  ])
+  @test_util.run_v2_only
+  def testTrtGraphConverter_DevicePlacement(self, device_id):
+    """Test case for trt_convert.TrtGraphConverterV2()."""
+
+    gpus = config.list_physical_devices("GPU")
+    if len(gpus) < 2:
+      self.skipTest("Expected at least 2 GPUs but found {} GPUs".format(
+          len(gpus)))
+
+    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+
+    # Create a model and save it.
+    input_saved_model_dir = self.mkdtemp()
+    root = self._GetModelForV2()
+    save.save(root, input_saved_model_dir,
+              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
+
+    converter = self._CreateConverterV2(
+        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
+
+    converted_model = None
+    # Specify the device on which the converted model should be placed.
+    with ops.device(device_id):
+      converted_model = converter.convert()
+
+    # Verify that the TRT engine op has the correct device.
+    self._CheckTrtOps(converter._converted_func)
+
+    actual_device_id = self._GetUniqueTRTEngineOp(
+        converter._converted_graph_def).device
+
+    expected_device_id = None
+    if device_id is not None:
+      expected_device_id = device_id
+    else:
+      expected_device_id = "GPU:0"
+
+    self.assertTrue(expected_device_id.lower() in actual_device_id.lower())
+
+    del converter
+    gc.collect()  # Force GC to destroy the TRT engine cache.
+
+  @test_util.run_v2_only
+  def testTrtGraphConverter_DevicePlacementOnCPU(self):
+    """Test case for trt_convert.TrtGraphConverterV2()."""
+
+    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
+
+    # Create a model and save it.
+    input_saved_model_dir = self.mkdtemp()
+    root = self._GetModelForV2()
+    save.save(root, input_saved_model_dir,
+              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
+
+    # Run TRT conversion.
+    converter = self._CreateConverterV2(
+        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
+
+    converted_model = None
+    # Conversion must fail when the requested device is not a GPU.
+    with self.assertRaisesRegex(ValueError, r"Specified device is not a GPU"):
+      with ops.device("CPU"):
+        converted_model = converter.convert()
+
+    del converter
+    gc.collect()  # Force GC to destroy the TRT engine cache.
 
 if __name__ == "__main__" and is_tensorrt_enabled():
   test.main()
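Taken together, the change means device placement for TF-TRT conversion is now controlled by the caller's device scope. A hedged end-to-end sketch of that workflow, assuming TF 2.x with TensorRT enabled, a SavedModel at the hypothetical path my_saved_model, and a second GPU:

    import tensorflow as tf
    from tensorflow.python.compiler.tensorrt import trt_convert

    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir="my_saved_model",  # hypothetical path
        precision_mode=trt_convert.TrtPrecisionMode.FP32)

    # The device scope around convert() now chooses where the imported
    # graph (and hence the TRTEngineOp) is placed; GPU:0 remains the
    # default, and a CPU scope raises ValueError.
    with tf.device("GPU:1"):
      converter.convert()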
