Skip to content

Commit

Permalink
fix: get_hash function for engine caching (#3293)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored Nov 15, 2024
1 parent 0841f34 commit afb1516
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
40 changes: 30 additions & 10 deletions py/torch_tensorrt/dynamo/_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
import pickletools
import shutil
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch._inductor.codecache import sha256_hash
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._settings import (
_SETTINGS_TO_BE_ENGINE_INVARIANT,
Expand Down Expand Up @@ -49,17 +48,38 @@ def get_hash(
Args:
gm (torch.fx.GraphModule): GraphModule to hash
input_specs (Sequence[Input]): input specs for the GraphModule
settings (CompilationSettings): compilation settings for the GraphModule
Returns:
str: hash value of the GraphModule
"""
# parameters are set to 0
with unset_fake_temporarily():
new_gm = copy.deepcopy(gm)
for name, param in new_gm.named_parameters():
param.data.zero_()

graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
def canonicalize_graph(graph: torch.fx.Graph) -> str:
"""Canonicalize the graph to a string for isomorphic graph comparison
Args:
graph (torch.fx.Graph): graph to canonicalize
Returns:
str: canonicalized graph string
"""
canonical_nodes = []
input_counter = 0

for node in graph.nodes:
if node.op == "placeholder":
canonical_nodes.append(f"placeholder_input_{input_counter}")
input_counter += 1
else:
canonical_nodes.append(f"{node.op}_{node.target}")

return " ".join(canonical_nodes)

graph_str = canonicalize_graph(gm.graph)
_LOGGER.debug(f"graph_str:\n {graph_str}")

graph_hash = sha256_hash(graph_str.encode())

input_spec_strs = [str(i) for i in input_specs]
with io.BytesIO() as stream:
Expand All @@ -75,7 +95,7 @@ def get_hash(
engine_specs_data = pickletools.optimize(engine_specs_data)
engine_specs_hash = sha256_hash(engine_specs_data)

hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash
hash_val: str = graph_hash + input_specs_hash + engine_specs_hash

return hash_val

Expand Down
4 changes: 0 additions & 4 deletions tests/py/dynamo/models/test_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
)
end.record()
torch.cuda.synchronize()
torch._dynamo.reset()
times.append(start.elapsed_time(end))
results.append(trt_gm(*inputs))

Expand Down Expand Up @@ -396,7 +395,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
"reuse_cached_engines": reuse_cached_engines,
"engine_cache_dir": engine_cache_dir,
"engine_cache_size": 1 << 30, # 1GB
"torch_executed_ops": {"torch.ops.aten.relu.default"},
},
)
results.append(compiled_model(*inputs)) # trigger the compilation
Expand Down Expand Up @@ -441,7 +439,6 @@ def test_torch_compile_with_custom_engine_cache(self):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for i in range(3):
# remove timing cache and reset dynamo for engine caching messurement
if i == 0:
cache_built_engines = False
reuse_cached_engines = False
Expand All @@ -462,7 +459,6 @@ def test_torch_compile_with_custom_engine_cache(self):
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": custom_engine_cache,
"torch_executed_ops": {"torch.ops.aten.relu.default"},
},
)
results.append(compiled_model(*inputs)) # trigger the compilation
Expand Down

0 comments on commit afb1516

Please sign in to comment.