Skip to content

Commit

Permalink
enable nvfuser matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 committed Apr 16, 2024
1 parent b4295cd commit 47eb3dc
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,26 @@ def get_num_fusions(cfn):
assert get_num_fusions(cfn_without_fusion) == 0

nvfuserex.set_fuel(thunder.extend.FUEL_LEVEL.UNLIMITED)


@instantiate(
    dtypes=(thunder.float16, thunder.bfloat16),
    devicetypes=(devices.DeviceType.CUDA,),
    executors=(nvFuserExecutor,),
)
def test_matmul(executor, device: str, dtype: dtypes.dtype):
    """Check a half-precision matmul compiled with ``nv_enable_matmul=True``.

    Two random matrices of shape (m, k) and (k, n) are multiplied through a
    function compiled with the executors under test.  The test passes when
    the final execution trace reports exactly one fusion and the result
    agrees with eager ``torch.matmul`` on the same inputs.

    Args:
        executor: Test executor supplied by ``instantiate``; its
            ``executors_list()`` is forwarded to ``thunder.compile``.
        device: Device string on which the input tensors are allocated.
        dtype: thunder dtype under test; converted to the matching torch dtype.
    """
    m, n, k = 128, 64, 32
    torch_dtype = ltorch.to_torch_dtype(dtype)
    a = torch.randn((m, k), dtype=torch_dtype, device=device)
    b = torch.randn((k, n), dtype=torch_dtype, device=device)

    def fn(a, b):
        return a.matmul(b)

    compiled_func = thunder.compile(
        fn,
        executors_list=executor.executors_list(),
        nv_enable_matmul=True,
    )
    out = compiled_func(a, b)

    # Inspect the last trace produced for this call and count its fusions.
    traces = thunder.last_traces(compiled_func)
    fusions = examine.get_fusions(traces[-1])
    assert len(fusions) == 1

    # assert_close selects tolerances from the tensor dtype and reports the
    # mismatch on failure; allclose's fixed defaults (rtol=1e-5, atol=1e-8)
    # are tighter than float16/bfloat16 precision and make the check fragile.
    torch.testing.assert_close(out, torch.matmul(a, b))

0 comments on commit 47eb3dc

Please sign in to comment.