Skip to content

Commit

Permalink
Filter hopper in Ampere targets
Browse files Browse the repository at this point in the history
  • Loading branch information
lezcano committed Nov 25, 2024
1 parent cbf00e2 commit baf6783
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
dtypes_with_bfloat16,
is_cuda,
is_interpreter,
is_hopper,
is_hip,
is_hip_cdna,
is_hip_mi200,
Expand Down Expand Up @@ -195,7 +196,12 @@ def is_layout_applicable(layout) -> bool:
if layout in common_layouts:
return True
elif is_cuda():
return isinstance(layout, MmaLayout)
mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
if not isinstance(mma_layout, MmaLayout):
return False
if mma_layout.version[0] >= 3 and not is_hopper():
return False
return True
elif is_hip():
target_arch = triton.runtime.driver.active.get_current_target().arch
if "gfx11" in target_arch:
Expand Down Expand Up @@ -5300,9 +5306,9 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):

@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
@pytest.mark.parametrize("dst_layout", layouts)
@pytest.mark.parametrize("dst_layout", filter_layouts(layouts))
def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path):
if str(src_layout) == str(dst_layout):
pytest.skip()
Expand Down

0 comments on commit baf6783

Please sign in to comment.