Skip to content

Commit

Permalink
Fix estimated memory test (#1316)
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 authored Oct 17, 2024
1 parent eb771b0 commit a1c2cb8
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions thunder/tests/test_examine_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,6 @@ def bar2(a, b): # [5,2], [2,2]

 @requiresCUDA
 def test_nanogpt_block():
-    # The estimated memory usage is not the same as actual peak memory usage on Hopper
-    if torch.cuda.get_device_capability() >= (9, 0):
-        pytest.skip(
-            f"the estimated memory usage is not the same as actual peak memory usage on {torch.cuda.get_device_name()}"
-        )
     import thunder.tests.nanogpt_model as nanogpt_model

     config = nanogpt_model.GPTConfig(dropout=0)
Expand All @@ -116,8 +111,9 @@ def test_nanogpt_block():
     max_mem_fw = get_alloc_memory(fw_trace)
     max_mem_bw = get_alloc_memory(bw_trace)

-    result = measure_fw_and_bw_memory_usage(fw_trace, bw_trace)
-    assert max_mem_fw[0] == result["fw_peak"]
-    assert sum(max_mem_fw[1].values()) == result["fw_current"]
-    assert max_mem_bw[0] == result["bw_peak"]
-    assert sum(max_mem_bw[1].values()) == result["bw_current"]
+    # Actual memory usage may vary depending on hardware and cuBLAS settings.
+    # We are checking the estimated memory against a fixed value for consistency.
+    assert max_mem_fw[0] == 381754368
+    assert sum(max_mem_fw[1].values()) == 375462912
+    assert max_mem_bw[0] == 437292032
+    assert sum(max_mem_bw[1].values()) == 40934400

0 comments on commit a1c2cb8

Please sign in to comment.