Skip to content

Commit

Permalink
pre-commit: running and fixing...
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] committed Apr 16, 2024
1 parent 09d9ab8 commit 96f1bd8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
3 changes: 3 additions & 0 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,6 +1948,7 @@ def var_mean(

register_supported(PrimIDs.VAR_MEAN, var_mean, _var_mean_check)


def _matmul_check(
a: TensorProxy,
b: TensorProxy,
Expand All @@ -1957,6 +1958,7 @@ def _matmul_check(
enable_matmul = False
return enable_matmul and is_supported_tensor(a) and is_supported_tensor(b)


def matmul(
a: TensorProxy,
b: TensorProxy,
Expand All @@ -1968,6 +1970,7 @@ def matmul(
nvb = getnv(b, fd, lc_to_nv_map)
return fd.ops.matmul(nva, nvb)


register_supported(PrimIDs.MATMUL, matmul, _matmul_check)


Expand Down
10 changes: 6 additions & 4 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,15 +852,17 @@ def get_num_fusions(cfn):
nvfuserex.set_fuel(thunder.extend.FUEL_LEVEL.UNLIMITED)


@instantiate(dtypes=(thunder.float16, thunder.bfloat16), devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,))
@instantiate(
dtypes=(thunder.float16, thunder.bfloat16), devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,)
)
def test_matmul(executor, device: str, dtype: dtypes.dtype):
m, n, k = 128, 64, 32
torch_dtype = ltorch.to_torch_dtype(dtype)
a = torch.randn((m, k), dtype=torch_dtype, device=device)
b = torch.randn((k, n), dtype=torch_dtype, device=device)

def fn(a , b):
return a.matmul(b);
def fn(a, b):
return a.matmul(b)

compiled_func = thunder.compile(
fn,
Expand All @@ -872,4 +874,4 @@ def fn(a , b):
traces = thunder.last_traces(compiled_func)
fusions = examine.get_fusions(traces[-1])
assert len(fusions) == 1
assert torch.allclose(out, torch.matmul(a, b))
assert torch.allclose(out, torch.matmul(a, b))

0 comments on commit 96f1bd8

Please sign in to comment.