diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py index 982d9d8ff3..0c5200905e 100644 --- a/thunder/tests/test_examine_memory.py +++ b/thunder/tests/test_examine_memory.py @@ -98,11 +98,6 @@ def bar2(a, b): # [5,2], [2,2] @requiresCUDA def test_nanogpt_block(): - # The estimated memory usage is not the same as actual peak memory usage on Hopper - if torch.cuda.get_device_capability() >= (9, 0): - pytest.skip( - f"the estimated memory usage is not the same as actual peak memory usage on {torch.cuda.get_device_name()}" - ) import thunder.tests.nanogpt_model as nanogpt_model config = nanogpt_model.GPTConfig(dropout=0) @@ -116,8 +111,9 @@ def test_nanogpt_block(): max_mem_fw = get_alloc_memory(fw_trace) max_mem_bw = get_alloc_memory(bw_trace) - result = measure_fw_and_bw_memory_usage(fw_trace, bw_trace) - assert max_mem_fw[0] == result["fw_peak"] - assert sum(max_mem_fw[1].values()) == result["fw_current"] - assert max_mem_bw[0] == result["bw_peak"] - assert sum(max_mem_bw[1].values()) == result["bw_current"] + # Actual memory usage may vary depending on hardware and cuBLAS settings. + # We check the estimated memory usage against fixed expected values for consistency. + assert max_mem_fw[0] == 381754368 + assert sum(max_mem_fw[1].values()) == 375462912 + assert max_mem_bw[0] == 437292032 + assert sum(max_mem_bw[1].values()) == 40934400