From afb1516e7cdf011e24b5e573e7afb92c3c4c0fdc Mon Sep 17 00:00:00 2001
From: "Zewen (Evan) Li"
Date: Fri, 15 Nov 2024 12:06:17 -0800
Subject: [PATCH] fix: get_hash function for engine caching (#3293)

---
 py/torch_tensorrt/dynamo/_engine_cache.py   | 40 +++++++++++++++------
 tests/py/dynamo/models/test_engine_cache.py |  4 ---
 2 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_engine_cache.py b/py/torch_tensorrt/dynamo/_engine_cache.py
index f166b489cb..7835c419d0 100644
--- a/py/torch_tensorrt/dynamo/_engine_cache.py
+++ b/py/torch_tensorrt/dynamo/_engine_cache.py
@@ -6,11 +6,10 @@
 import pickletools
 import shutil
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import torch
-from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
-from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch._inductor.codecache import sha256_hash
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._settings import (
     _SETTINGS_TO_BE_ENGINE_INVARIANT,
@@ -49,17 +48,38 @@ def get_hash(
 
         Args:
             gm (torch.fx.GraphModule): GraphModule to hash
+            input_specs (Sequence[Input]): input specs for the GraphModule
+            settings (CompilationSettings): compilation settings for the GraphModule
 
         Returns:
             str: hash value of the GraphModule
         """
-        # parameters are set to 0
-        with unset_fake_temporarily():
-            new_gm = copy.deepcopy(gm)
-            for name, param in new_gm.named_parameters():
-                param.data.zero_()
-
-            graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
+
+        def canonicalize_graph(graph: torch.fx.Graph) -> str:
+            """Canonicalize the graph to a string for isomorphic graph comparison
+
+            Args:
+                graph (torch.fx.Graph): graph to canonicalize
+
+            Returns:
+                str: canonicalized graph string
+            """
+            canonical_nodes = []
+            input_counter = 0
+
+            for node in graph.nodes:
+                if node.op == "placeholder":
+                    canonical_nodes.append(f"placeholder_input_{input_counter}")
+                    input_counter += 1
+                else:
+                    canonical_nodes.append(f"{node.op}_{node.target}")
+
+            return " ".join(canonical_nodes)
+
+        graph_str = canonicalize_graph(gm.graph)
+        _LOGGER.debug(f"graph_str:\n {graph_str}")
+
+        graph_hash = sha256_hash(graph_str.encode())
 
         input_spec_strs = [str(i) for i in input_specs]
         with io.BytesIO() as stream:
@@ -75,7 +95,7 @@ def get_hash(
             engine_specs_data = pickletools.optimize(engine_specs_data)
             engine_specs_hash = sha256_hash(engine_specs_data)
 
-        hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash
+        hash_val: str = graph_hash + input_specs_hash + engine_specs_hash
 
         return hash_val
 
diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py
index 367f68c1f6..5ceea5e381 100644
--- a/tests/py/dynamo/models/test_engine_cache.py
+++ b/tests/py/dynamo/models/test_engine_cache.py
@@ -231,7 +231,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
             )
             end.record()
             torch.cuda.synchronize()
-            torch._dynamo.reset()
             times.append(start.elapsed_time(end))
             results.append(trt_gm(*inputs))
 
@@ -396,7 +395,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
                 "reuse_cached_engines": reuse_cached_engines,
                 "engine_cache_dir": engine_cache_dir,
                 "engine_cache_size": 1 << 30,  # 1GB
-                "torch_executed_ops": {"torch.ops.aten.relu.default"},
             },
         )
         results.append(compiled_model(*inputs))  # trigger the compilation
@@ -441,7 +439,6 @@ def test_torch_compile_with_custom_engine_cache(self):
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         for i in range(3):
-            # remove timing cache and reset dynamo for engine caching messurement
             if i == 0:
                 cache_built_engines = False
                 reuse_cached_engines = False
@@ -462,7 +459,6 @@ def test_torch_compile_with_custom_engine_cache(self):
                 "cache_built_engines": cache_built_engines,
                 "reuse_cached_engines": reuse_cached_engines,
                 "custom_engine_cache": custom_engine_cache,
-                "torch_executed_ops": {"torch.ops.aten.relu.default"},
             },
         )
         results.append(compiled_model(*inputs))  # trigger the compilation
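
Note (reviewer illustration, not part of the patch): the cache key now hashes a
canonicalized string of the FX graph instead of a pickle of a parameter-zeroed
deepcopy, so graphs that are isomorphic (same ops in the same order) map to the
same key regardless of weight values or placeholder names. The sketch below
demonstrates that property; build_graph is a hypothetical helper, and
hashlib.sha256 stands in for the torch._inductor.codecache.sha256_hash used in
the patch.

    import hashlib

    import torch


    def canonicalize_graph(graph: torch.fx.Graph) -> str:
        # Same logic as the helper added in this patch: placeholders are
        # renamed positionally; every other node keeps its op and target.
        canonical_nodes = []
        input_counter = 0
        for node in graph.nodes:
            if node.op == "placeholder":
                canonical_nodes.append(f"placeholder_input_{input_counter}")
                input_counter += 1
            else:
                canonical_nodes.append(f"{node.op}_{node.target}")
        return " ".join(canonical_nodes)


    def build_graph(in_name: str, w_name: str) -> torch.fx.Graph:
        # Hand-built aten-level graph: placeholders plus aten OpOverload
        # calls, roughly the shape of the lowered graphs the cache sees.
        g = torch.fx.Graph()
        x = g.placeholder(in_name)
        w = g.placeholder(w_name)
        mm = g.call_function(torch.ops.aten.mm.default, (x, w))
        relu = g.call_function(torch.ops.aten.relu.default, (mm,))
        g.output(relu)
        return g


    # Different placeholder names, identical structure: the canonical
    # strings match, so the graph hashes match.
    s1 = canonicalize_graph(build_graph("x", "w"))
    s2 = canonicalize_graph(build_graph("input_0", "weight"))
    assert s1 == s2
    print(s1)
    # placeholder_input_0 placeholder_input_1 call_function_aten.mm.default
    # call_function_aten.relu.default output_output

    assert (
        hashlib.sha256(s1.encode()).hexdigest()
        == hashlib.sha256(s2.encode()).hexdigest()
    )

As the @@ -75,7 +95,7 @@ hunk shows, this graph hash is only one of three
components of the final key; input_specs_hash and engine_specs_hash still
separate entries whose input shapes or engine-invariant compilation settings
differ.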