Introducing Arm SME/SVE2 Optimization pass #109

Open. Wants to merge 7 commits into base: main.
Conversation

@danikhan632 (Contributor) commented Mar 5, 2024

Since this is an optimization pass over existing triton-shared output, I decided to make it its own binary. It currently gets past the optimization phase and fails in _ttsharedir_to_llir, which is to be expected since it needs different mlir-opt flags. These flags have also just been recently updated.

I am also trying to introduce bf16/f16 support, as well as make the current optimization passes apply only to hardware that can support them.

There are more plans for optimization beyond the tile-and-outerproduct approach seen here, but the current build does produce valid MLIR. I based this off the example shown here.

As of now, only SVE2 can be tested on real hardware, which I don't have access to; SME will have to be emulated.
This is nowhere near ready to be merged, but feedback would be appreciated.

Instructions to build

Build as normal; to see the optimized MLIR, follow the usage below.
Usage

# dependent on cmake/python/arch details
export TRITON_SME_PATH="$(pwd)/python/build/cmake.linux-aarch64-cpython-3.11/third_party/triton_shared/tools/triton-sme-opt/triton-sme-opt"
cd ./third_party/triton_shared/python/examples
rm -rf ~/.triton/cache
python3 test_matmul.py

This should cause the test to fail to compile, but the optimized MLIR should be printed in blue text.
To turn this off, just set:

export TRITON_SME_PATH=""

Below is the optimized MLIR produced from test_matmul.py

#map = affine_map<(d0, d1) -> (-d0 + 32, d1)>
#map1 = affine_map<(d0, d1) -> (-d0 + 64, d1)>
#map2 = affine_map<(d0, d1) -> (d0, 0, d1)>
#map3 = affine_map<(d0, d1) -> (0, d1, d0)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map6 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: memref<*xf16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %cst = arith.constant dense<false> : vector<1x[4]xi1>
    %c4 = arith.constant 4 : index
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst_1 = arith.constant 0.000000e+00 : f16
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    linalg.fill ins(%cst_0 : f32) outs(%alloc : memref<32x64xf32>)
    %0 = bufferization.to_tensor %alloc : memref<32x64xf32>
    %1 = arith.addi %arg3, %c31_i32 : i32
    %2 = arith.divsi %1, %c32_i32 : i32
    %3 = arith.addi %arg4, %c63_i32 : i32
    %4 = arith.divsi %3, %c64_i32 : i32
    %5 = arith.muli %4, %c8_i32 : i32
    %6 = arith.divsi %arg12, %5 : i32
    %7 = arith.muli %6, %c8_i32 : i32
    %8 = arith.subi %2, %7 : i32
    %9 = arith.minsi %8, %c8_i32 : i32
    %10 = arith.remsi %arg12, %9 : i32
    %11 = arith.addi %7, %10 : i32
    %12 = arith.remsi %arg12, %5 : i32
    %13 = arith.divsi %12, %9 : i32
    %14 = arith.muli %11, %c32_i32 : i32
    %15 = arith.index_cast %14 : i32 to index
    %16 = arith.muli %13, %c64_i32 : i32
    %17 = arith.index_cast %16 : i32 to index
    %18 = arith.index_cast %arg3 : i32 to index
    %19 = arith.index_cast %arg6 : i32 to index
    %20 = arith.muli %15, %19 : index
    %21 = arith.muli %18, %19 : index
    %22 = arith.index_cast %arg7 : i32 to index
    %23 = arith.index_cast %arg4 : i32 to index
    %24 = arith.addi %arg5, %c15_i32 : i32
    %25 = arith.divsi %24, %c16_i32 : i32
    %26 = arith.muli %arg7, %c16_i32 : i32
    %27 = arith.index_cast %26 : i32 to index
    %28:3 = scf.for %arg15 = %c0_i32 to %25 step %c1_i32 iter_args(%arg16 = %0, %arg17 = %20, %arg18 = %c0) -> (tensor<32x64xf32>, index, index)  : i32 {
      %41 = bufferization.to_memref %arg16 : memref<32x64xf32>
      %42 = arith.addi %arg18, %17 : index
      %43 = arith.remsi %42, %23 : index
      %44 = arith.subi %42, %43 : index
      %45 = arith.addi %43, %c64 : index
      %46 = arith.minsi %45, %23 : index
      %47 = arith.subi %46, %43 : index
      %reinterpret_cast_4 = memref.reinterpret_cast %arg1 to offset: [%42], sizes: [%c16, %47], strides: [%22, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %48 = arith.subi %c64, %47 : index
      %reinterpret_cast_5 = memref.reinterpret_cast %arg1 to offset: [%44], sizes: [%c16, %48], strides: [%22, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %49 = arith.remsi %arg17, %19 : index
      %50 = arith.addi %21, %49 : index
      %51 = arith.subi %50, %arg17 : index
      %52 = arith.divsi %51, %19 : index
      %reinterpret_cast_6 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%52, %c16], strides: [%19, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %53 = arith.subi %c32, %52 : index
      %reinterpret_cast_7 = memref.reinterpret_cast %arg0 to offset: [%49], sizes: [%53, %c16], strides: [%19, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %54 = arith.muli %arg15, %c16_i32 : i32
      %55 = arith.subi %arg5, %54 : i32
      %56 = arith.index_cast %55 : i32 to index
      %57 = arith.minsi %56, %c16 : index
      %alloc_8 = memref.alloc() : memref<32x16xf16>
      %58 = arith.cmpi slt, %57, %c16 : index
      scf.if %58 {
        linalg.fill ins(%cst_1 : f16) outs(%alloc_8 : memref<32x16xf16>)
      }
      %59 = arith.minsi %52, %c32 : index
      %60 = arith.subi %c32, %59 : index
      %subview_9 = memref.subview %reinterpret_cast_6[0, 0] [%59, %57] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_10 = memref.subview %reinterpret_cast_7[0, 0] [%60, %57] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_11 = memref.subview %alloc_8[0, 0] [%59, %57] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1]>>
      %subview_12 = memref.subview %alloc_8[%59, 0] [%60, %57] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      memref.copy %subview_9, %subview_11 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1]>>
      memref.copy %subview_10, %subview_12 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      %alloc_13 = memref.alloc() : memref<16x64xf16>
      %61 = arith.cmpi slt, %57, %c16 : index
      scf.if %61 {
        linalg.fill ins(%cst_1 : f16) outs(%alloc_13 : memref<16x64xf16>)
      }
      %62 = arith.minsi %47, %c64 : index
      %63 = arith.subi %c64, %62 : index
      %subview_14 = memref.subview %reinterpret_cast_4[0, 0] [%57, %62] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_15 = memref.subview %reinterpret_cast_5[0, 0] [%57, %63] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_16 = memref.subview %alloc_13[0, 0] [%57, %62] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1]>>
      %subview_17 = memref.subview %alloc_13[0, %62] [%57, %63] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      memref.copy %subview_14, %subview_16 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1]>>
      memref.copy %subview_15, %subview_17 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      %64 = vector.vscale
      %65 = arith.muli %64, %c4 : index
      %66 = arith.muli %64, %c4 : index
      %67 = scf.for %arg19 = %c0 to %c32 step %65 iter_args(%arg20 = %0) -> (tensor<32x64xf32>) {
        %72 = scf.for %arg21 = %c0 to %c64 step %66 iter_args(%arg22 = %arg20) -> (tensor<32x64xf32>) {
          %73 = scf.for %arg23 = %c0 to %c16 step %c1 iter_args(%arg24 = %arg22) -> (tensor<32x64xf32>) {
            %74 = bufferization.to_memref %arg24 : memref<32x64xf32>
            %75 = bufferization.to_memref %arg24 : memref<32x64xf32>
            %76 = affine.min #map(%arg19, %65)
            %77 = affine.min #map1(%arg21, %66)
            %78 = affine.min #map(%arg19, %65)
            %79 = affine.min #map1(%arg21, %66)
            %subview_19 = memref.subview %alloc_8[%arg19, %arg23] [%76, 1] [1, 1] : memref<32x16xf16> to memref<?x1xf16, strided<[16, 1], offset: ?>>
            %80 = bufferization.to_tensor %subview_19 : memref<?x1xf16, strided<[16, 1], offset: ?>>
            %subview_20 = memref.subview %alloc_13[%arg23, %arg21] [1, %77] [1, 1] : memref<16x64xf16> to memref<1x?xf16, strided<[64, 1], offset: ?>>
            %81 = bufferization.to_tensor %subview_20 : memref<1x?xf16, strided<[64, 1], offset: ?>>
            %subview_21 = memref.subview %75[%arg19, %arg21] [%78, %79] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            %82 = bufferization.to_tensor %subview_21 : memref<?x?xf32, strided<[64, 1], offset: ?>>
            %83 = vector.create_mask %76, %c1 : vector<[4]x1xi1>
            %84 = vector.transfer_read %80[%c0, %c0], %cst_1, %83 {in_bounds = [true, true, true], permutation_map = #map2} : tensor<?x1xf16>, vector<[4]x[4]x1xf16>
            %85 = vector.create_mask %77 : vector<[4]xi1>
            %86 = vector.insert %85, %cst [0] : vector<[4]xi1> into vector<1x[4]xi1>
            %87 = vector.transfer_read %81[%c0, %c0], %cst_1, %86 {in_bounds = [true, true, true], permutation_map = #map3} : tensor<1x?xf16>, vector<[4]x[4]x1xf16>
            %88 = vector.create_mask %76, %77 : vector<[4]x[4]xi1>
            %89 = vector.transfer_read %82[%c0, %c0], %cst_0, %88 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x[4]xf32>
            %90 = arith.extf %84 : vector<[4]x[4]x1xf16> to vector<[4]x[4]x1xf32>
            %91 = arith.extf %87 : vector<[4]x[4]x1xf16> to vector<[4]x[4]x1xf32>
            %92 = vector.create_mask %76, %77, %c1 : vector<[4]x[4]x1xi1>
            %93 = vector.mask %92 { vector.contract {indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %90, %91, %89 : vector<[4]x[4]x1xf32>, vector<[4]x[4]x1xf32> into vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>
            %94 = vector.transfer_write %93, %82[%c0, %c0], %88 {in_bounds = [true, true]} : vector<[4]x[4]xf32>, tensor<?x?xf32>
            %95 = bufferization.to_memref %94 : memref<?x?xf32>
            %alloc_22 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
            memref.copy %74, %alloc_22 : memref<32x64xf32> to memref<32x64xf32>
            %subview_23 = memref.subview %alloc_22[%arg19, %arg21] [%78, %79] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            memref.copy %95, %subview_23 : memref<?x?xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            %96 = bufferization.to_tensor %alloc_22 : memref<32x64xf32>
            scf.yield %96 : tensor<32x64xf32>
          }
          scf.yield %73 : tensor<32x64xf32>
        }
        scf.yield %72 : tensor<32x64xf32>
      }
      %68 = bufferization.to_memref %67 : memref<32x64xf32>
      %alloc_18 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%68, %41 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_18 : memref<32x64xf32>) {
      ^bb0(%in: f32, %in_19: f32, %out: f32):
        %72 = arith.addf %in, %in_19 : f32
        linalg.yield %72 : f32
      }
      %69 = bufferization.to_tensor %alloc_18 : memref<32x64xf32>
      %70 = arith.addi %arg17, %c16 : index
      %71 = arith.addi %arg18, %27 : index
      scf.yield %69, %70, %71 : tensor<32x64xf32>, index, index
    }
    %29 = bufferization.to_memref %28#0 : memref<32x64xf32>
    %30 = arith.index_cast %arg8 : i32 to index
    %31 = arith.muli %15, %30 : index
    %32 = arith.addi %31, %17 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf16> to memref<32x64xf16, strided<[?, 1], offset: ?>>
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf16>
    linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%29 : memref<32x64xf32>) outs(%alloc_2 : memref<32x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %41 = arith.truncf %in : f32 to f16
      linalg.yield %41 : f16
    }
    %33 = arith.addi %15, %c32 : index
    %34 = arith.minsi %33, %18 : index
    %35 = arith.subi %34, %15 : index
    %36 = arith.addi %17, %c64 : index
    %37 = arith.minsi %36, %23 : index
    %38 = arith.subi %37, %17 : index
    %39 = arith.minsi %35, %c32 : index
    %40 = arith.minsi %38, %c64 : index
    %subview = memref.subview %alloc_2[0, 0] [%39, %40] [1, 1] : memref<32x64xf16> to memref<?x?xf16, strided<[64, 1]>>
    %subview_3 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf16, strided<[?, 1], offset: ?>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    memref.copy %subview, %subview_3 : memref<?x?xf16, strided<[64, 1]>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    return
  }
}

@@ -32,16 +44,25 @@ def _ttir_to_ttsharedir(mod):
 dst_path = os.path.join(tmpdir, "ttshared.mlir")
 Path(src_path).write_text(ttir_code)
 triton_shared_opt_path = _get_triton_shared_opt_path()
-subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-structured", "--canonicalize", "--triton-arith-to-linalg", "--cse", "--structured-to-memref", "-o", dst_path])
+subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg", "-o", dst_path])
Collaborator:
Could we revert this? We're in the process of migrating away from the monolith pass, so any future work should ideally be tested using the new modular passes.

Contributor Author:
Could definitely be done. My GCC compiler didn't like the modular pass format, but I think I just have to reinstall.

@danikhan632 (Contributor Author):
Also, I know that we just merged the arm-workflow runner, but I might want to get rid of it since I have working changes for float16 and bfloat16. For quite some time, triton-shared has been swapping out bf16/fp16 for f32, and I am working on optional support for when the current system supports avx512_bf16 (x86), sve-bf16 (Arm), or fp16 instructions. Wondering if the runner could be changed at some point in the future.
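Not part of this PR, but a minimal sketch of the kind of capability check this implies, assuming a Linux host where /proc/cpuinfo exposes the relevant feature flags; the exact flag strings below are assumptions and would need verifying per platform:

import platform
from pathlib import Path

def native_low_precision_support() -> dict:
    """Hypothetical sketch: report whether the host advertises bf16/fp16 support,
    so the lowering could keep bf16/fp16 instead of widening to f32.
    The cpuinfo flag names are assumptions, not taken from the PR."""
    try:
        cpuinfo = Path("/proc/cpuinfo").read_text().lower()
    except OSError:
        cpuinfo = ""
    machine = platform.machine().lower()
    if machine in ("x86_64", "amd64"):
        # e.g. AVX-512 BF16 / FP16 extensions on recent x86 cores
        return {"bf16": "avx512_bf16" in cpuinfo, "fp16": "avx512fp16" in cpuinfo}
    if machine in ("aarch64", "arm64"):
        # e.g. SVE BF16 and Advanced SIMD half-precision on Arm cores
        return {"bf16": "svebf16" in cpuinfo or " bf16" in cpuinfo,
                "fp16": "asimdhp" in cpuinfo or "fphp" in cpuinfo}
    return {"bf16": False, "fp16": False}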

@aaronsm self-requested a review March 8, 2024 09:51
@danikhan632 (Contributor Author):
Getting an error when trying to pass the IR through mlir-opt. I tried some of the flags used in the SME example, but they're for the transform interpreter. Anybody got any ideas?

 error: failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal
            %114 = vector.mask %113 { vector.transfer_read %extracted_slice_21[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>

I think it's got something to do with the flags being passed in.

def _ttsharedir_to_llir(ttsharedir: str):
    with tempfile.TemporaryDirectory() as tmpdir:
        ttshared_path = os.path.join(tmpdir, "ttshared.mlir")
        llmlir_path = os.path.join(tmpdir, "ll.mlir")
        llir_path = os.path.join(tmpdir, "ll.ir")
        Path(ttshared_path).write_text(ttsharedir)
        mlir_opt_path = _get_llvm_bin_path("mlir-opt")
        
        # TritonShared-MLIR to LLVM-MLIR
        subprocess.check_call([
            mlir_opt_path,
            ttshared_path,
            "--convert-linalg-to-affine-loops",
            "--eliminate-empty-tensors",
            "--arm-sve-legalize-vector-storage",
            "--allocate-arm-sme-tiles",
            "--empty-tensor-to-alloc-tensor",
            "--one-shot-bufferize=allow-return-allocs-from-loops=true",
            "--lower-affine",
            "--convert-linalg-to-loops",
            "--convert-arm-sme-to-scf",
            "--convert-scf-to-cf",
            "--convert-cf-to-llvm",
            "--convert-arith-to-llvm",
            "--convert-math-to-llvm",
            "--convert-complex-to-llvm",
            "--convert-vector-to-arm-sme",
            "--convert-arm-sme-to-llvm",
            "--convert-index-to-llvm",
            "--memref-expand",
            "-convert-vector-to-llvm=enable-arm-sve",
            "--expand-strided-metadata",
            "--finalize-memref-to-llvm",
            "--convert-func-to-llvm",
            # Lowering memrefs creates more affine.apply ops.
            # Lowering these affine ops again creates further arith ops,
            # so we have to run these two passes again here.
            "--lower-affine",
            "--convert-arith-to-llvm",
            # Remove all unrealized casts created
            "--canonicalize",
            "-o",
            llmlir_path,
        ])

        # LLVM-MLIR to LLVM-IR
        mlir_translate_path = _get_llvm_bin_path("mlir-translate")
        subprocess.check_call([mlir_translate_path, llmlir_path,
            "--mlir-to-llvmir",
            "-o",
            llir_path])
        return Path(llir_path).read_text()

@zhaoshiz commented Mar 8, 2024

You may be missing the lowering of masked vector transfers.
https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir#L79
The flag is -lower-vector-mask; if that doesn't work, you can call it in C++.
https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp#L101-L104
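For concreteness, a minimal sketch of where that flag could slot into the existing _ttsharedir_to_llir invocation; only -lower-vector-mask itself comes from the comment above, and its exact position relative to the SME conversions is an assumption:

import subprocess

def lower_with_mask_lowering(mlir_opt_path: str, src_path: str, dst_path: str) -> None:
    """Hypothetical excerpt: the current pipeline with --lower-vector-mask added
    before the Arm SME conversions (placement is an assumption)."""
    subprocess.check_call([
        mlir_opt_path,
        src_path,
        # ... the earlier passes from _ttsharedir_to_llir stay unchanged ...
        "--lower-vector-mask",          # lower vector.mask regions, per the suggestion above
        "--convert-vector-to-arm-sme",
        "--convert-arm-sme-to-llvm",
        # ... the remaining passes stay unchanged ...
        "-o",
        dst_path,
    ])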

@danikhan632 (Contributor Author) commented Mar 8, 2024

> You may be missing the lowering of masked vector transfers. https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir#L79 The flag is -lower-vector-mask; if that doesn't work, you can call it in C++. https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp#L101-L104

That seemed to change the IR, but it didn't completely fix the issue.

This is the IR now being generated:

#map = affine_map<(d0, d1) -> (-d0 + 32, d1)>
#map1 = affine_map<(d0, d1) -> (-d0 + 64, d1)>
#map2 = affine_map<(d0)[s0] -> (d0 * 16 + s0)>
#map3 = affine_map<(d0)[s0] -> (d0 + s0)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map5 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map6 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map7 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: memref<*xf16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c4 = arith.constant 4 : index
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst = arith.constant 0.000000e+00 : f32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst_0 = arith.constant 0.000000e+00 : f16
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
    %0 = arith.addi %arg3, %c31_i32 : i32
    %1 = arith.divsi %0, %c32_i32 : i32
    %2 = arith.addi %arg4, %c63_i32 : i32
    %3 = arith.divsi %2, %c64_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %arg12, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.subi %1, %6 : i32
    %8 = arith.minsi %7, %c8_i32 : i32
    %9 = arith.remsi %arg12, %8 : i32
    %10 = arith.addi %6, %9 : i32
    %11 = arith.remsi %arg12, %4 : i32
    %12 = arith.divsi %11, %8 : i32
    %13 = arith.muli %10, %c32_i32 : i32
    %14 = arith.index_cast %13 : i32 to index
    %15 = arith.muli %12, %c64_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.index_cast %arg3 : i32 to index
    %18 = arith.index_cast %arg6 : i32 to index
    %19 = arith.muli %14, %18 : index
    %20 = arith.muli %17, %18 : index
    %21 = arith.index_cast %arg7 : i32 to index
    %22 = arith.index_cast %arg4 : i32 to index
    %23 = arith.addi %arg5, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = arith.muli %arg7, %c16_i32 : i32
    %26 = arith.index_cast %25 : i32 to index
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    memref.copy %alloc, %alloc_1 : memref<32x64xf32> to memref<32x64xf32>
    %27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_1, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index)  : i32 {
      %39 = arith.addi %arg18, %16 : index
      %40 = arith.remsi %39, %22 : index
      %41 = arith.subi %39, %40 : index
      %42 = arith.addi %40, %c64 : index
      %43 = arith.minsi %42, %22 : index
      %44 = arith.subi %43, %40 : index
      %reinterpret_cast_4 = memref.reinterpret_cast %arg1 to offset: [%39], sizes: [%c16, %44], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %45 = arith.subi %c64, %44 : index
      %reinterpret_cast_5 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %45], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %46 = arith.remsi %arg17, %18 : index
      %47 = arith.addi %20, %46 : index
      %48 = arith.subi %47, %arg17 : index
      %49 = arith.divsi %48, %18 : index
      %reinterpret_cast_6 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%49, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %50 = arith.subi %c32, %49 : index
      %reinterpret_cast_7 = memref.reinterpret_cast %arg0 to offset: [%46], sizes: [%50, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %51 = arith.muli %arg15, %c16_i32 : i32
      %52 = arith.subi %arg5, %51 : i32
      %53 = arith.index_cast %52 : i32 to index
      %54 = arith.minsi %53, %c16 : index
      %alloc_8 = memref.alloc() : memref<32x16xf16>
      %55 = arith.cmpi slt, %54, %c16 : index
      scf.if %55 {
        linalg.fill ins(%cst_0 : f16) outs(%alloc_8 : memref<32x16xf16>)
      }
      %56 = arith.minsi %49, %c32 : index
      %57 = arith.subi %c32, %56 : index
      %subview_9 = memref.subview %reinterpret_cast_6[0, 0] [%56, %54] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_10 = memref.subview %reinterpret_cast_7[0, 0] [%57, %54] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_11 = memref.subview %alloc_8[0, 0] [%56, %54] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1]>>
      %subview_12 = memref.subview %alloc_8[%56, 0] [%57, %54] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      memref.copy %subview_9, %subview_11 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1]>>
      memref.copy %subview_10, %subview_12 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      %alloc_13 = memref.alloc() : memref<16x64xf16>
      %58 = arith.cmpi slt, %54, %c16 : index
      scf.if %58 {
        linalg.fill ins(%cst_0 : f16) outs(%alloc_13 : memref<16x64xf16>)
      }
      %59 = arith.minsi %44, %c64 : index
      %60 = arith.subi %c64, %59 : index
      %subview_14 = memref.subview %reinterpret_cast_4[0, 0] [%54, %59] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_15 = memref.subview %reinterpret_cast_5[0, 0] [%54, %60] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_16 = memref.subview %alloc_13[0, 0] [%54, %59] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1]>>
      %subview_17 = memref.subview %alloc_13[0, %59] [%54, %60] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      memref.copy %subview_14, %subview_16 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1]>>
      memref.copy %subview_15, %subview_17 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      %61 = vector.vscale
      %62 = arith.muli %61, %c4 : index
      %63 = arith.muli %61, %c4 : index
      %alloc_18 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      memref.copy %alloc, %alloc_18 : memref<32x64xf32> to memref<32x64xf32>
      %64 = scf.for %arg19 = %c0 to %c32 step %62 iter_args(%arg20 = %alloc_18) -> (memref<32x64xf32>) {
        %67 = scf.for %arg21 = %c0 to %c64 step %63 iter_args(%arg22 = %arg20) -> (memref<32x64xf32>) {
          %68 = scf.for %arg23 = %c0 to %c16 step %c1 iter_args(%arg24 = %arg22) -> (memref<32x64xf32>) {
            %69 = affine.min #map(%arg19, %62)
            %70 = affine.min #map1(%arg21, %63)
            %71 = affine.min #map(%arg19, %62)
            %72 = affine.min #map1(%arg21, %63)
            %subview_19 = memref.subview %alloc_8[%arg19, %arg23] [%69, 1] [1, 1] : memref<32x16xf16> to memref<?x1xf16, strided<[16, 1], offset: ?>>
            %subview_20 = memref.subview %alloc_13[%arg23, %arg21] [1, %70] [1, 1] : memref<16x64xf16> to memref<1x?xf16, strided<[64, 1], offset: ?>>
            %subview_21 = memref.subview %arg24[%arg19, %arg21] [%71, %72] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            %73 = vector.create_mask %69 : vector<[4]xi1>
            %subview_22 = memref.subview %subview_19[0, 0] [%69, 1] [1, 1] : memref<?x1xf16, strided<[16, 1], offset: ?>> to memref<?xf16, #map2>
            %74 = vector.transfer_read %subview_22[%c0], %cst_0, %73 {in_bounds = [true]} : memref<?xf16, #map2>, vector<[4]xf16>
            %75 = vector.shape_cast %74 : vector<[4]xf16> to vector<[4]x1xf16>
            %76 = vector.create_mask %70 : vector<[4]xi1>
            %subview_23 = memref.subview %subview_20[0, 0] [1, %70] [1, 1] : memref<1x?xf16, strided<[64, 1], offset: ?>> to memref<?xf16, #map3>
            %77 = vector.transfer_read %subview_23[%c0], %cst_0, %76 {in_bounds = [true]} : memref<?xf16, #map3>, vector<[4]xf16>
            %78 = vector.shape_cast %77 : vector<[4]xf16> to vector<1x[4]xf16>
            %79 = vector.create_mask %69, %70 : vector<[4]x[4]xi1>
            %80 = vector.transfer_read %subview_21[%c0, %c0], %cst, %79 {in_bounds = [true, true]} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]x[4]xf32>
            %81 = arith.extf %75 : vector<[4]x1xf16> to vector<[4]x1xf32>
            %82 = arith.extf %78 : vector<1x[4]xf16> to vector<1x[4]xf32>
            %83 = vector.create_mask %69, %70, %c1 : vector<[4]x[4]x1xi1>
            %84 = vector.mask %83 { vector.contract {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %81, %82, %80 : vector<[4]x1xf32>, vector<1x[4]xf32> into vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>
            vector.transfer_write %84, %subview_21[%c0, %c0], %79 {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32, strided<[64, 1], offset: ?>>
            %subview_24 = memref.subview %arg24[%arg19, %arg21] [%71, %72] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            memref.copy %subview_21, %subview_24 : memref<?x?xf32, strided<[64, 1], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            scf.yield %arg24 : memref<32x64xf32>
          }
          scf.yield %68 : memref<32x64xf32>
        }
        scf.yield %67 : memref<32x64xf32>
      }
      linalg.generic {indexing_maps = [#map7, #map7, #map7], iterator_types = ["parallel", "parallel"]} ins(%64, %arg16 : memref<32x64xf32>, memref<32x64xf32>) outs(%64 : memref<32x64xf32>) {
      ^bb0(%in: f32, %in_19: f32, %out: f32):
        %67 = arith.addf %in, %in_19 : f32
        linalg.yield %67 : f32
      }
      %65 = arith.addi %arg17, %c16 : index
      %66 = arith.addi %arg18, %26 : index
      scf.yield %64, %65, %66 : memref<32x64xf32>, index, index
    }
    %28 = arith.index_cast %arg8 : i32 to index
    %29 = arith.muli %14, %28 : index
    %30 = arith.addi %29, %16 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf16> to memref<32x64xf16, strided<[?, 1], offset: ?>>
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf16>
    linalg.generic {indexing_maps = [#map7, #map7], iterator_types = ["parallel", "parallel"]} ins(%27#0 : memref<32x64xf32>) outs(%alloc_2 : memref<32x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %39 = arith.truncf %in : f32 to f16
      linalg.yield %39 : f16
    }
    %31 = arith.addi %14, %c32 : index
    %32 = arith.minsi %31, %17 : index
    %33 = arith.subi %32, %14 : index
    %34 = arith.addi %16, %c64 : index
    %35 = arith.minsi %34, %22 : index
    %36 = arith.subi %35, %16 : index
    %37 = arith.minsi %33, %c32 : index
    %38 = arith.minsi %36, %c64 : index
    %subview = memref.subview %alloc_2[0, 0] [%37, %38] [1, 1] : memref<32x64xf16> to memref<?x?xf16, strided<[64, 1]>>
    %subview_3 = memref.subview %reinterpret_cast[0, 0] [%37, %38] [1, 1] : memref<32x64xf16, strided<[?, 1], offset: ?>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    memref.copy %subview, %subview_3 : memref<?x?xf16, strided<[64, 1]>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    return
  }
}

@zhaoshiz commented Mar 9, 2024

There's a masked vector.contract in the IR

            %117 = vector.create_mask %105, %106, %c1 : vector<[4]x[4]x1xi1>
            %118 = vector.mask %117 { vector.contract {indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %115, %116, %114 : vector<[4]x[4]x1xf32>, vector<[4]x[4]x1xf32> into vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>

Maybe lowering vector mask can help:
https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir#L89

@danikhan632 (Contributor Author) commented Mar 9, 2024

> There's a masked vector.contract in the IR
>
>             %117 = vector.create_mask %105, %106, %c1 : vector<[4]x[4]x1xi1>
>             %118 = vector.mask %117 { vector.contract {indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %115, %116, %114 : vector<[4]x[4]x1xf32>, vector<[4]x[4]x1xf32> into vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>
>
> Maybe lowering vector mask can help: https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir#L89

I figured out the issue: the outerproduct part seems to have no effect on the MLIR output. I'm trying to figure out why it does nothing.

struct OuterProductVectorizationPass
    : public PassWrapper<OuterProductVectorizationPass,
                         OperationPass<func::FuncOp>> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect, func::FuncDialect>();
  }

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

      // Apply patterns for lowering masked transfers
    transform::ApplyLowerMaskedTransfersPatternsOp lowerMaskedTransfersPatterns;
    lowerMaskedTransfersPatterns.populatePatterns(patterns);

    // Apply patterns for transfer permutation
    transform::ApplyTransferPermutationPatternsOp transferPermutationPatterns;
    transferPermutationPatterns.populatePatterns(patterns);

    // Apply patterns for reduction to contract
    transform::ApplyVectorReductionToContractPatternsOp reductionToContractPatterns;
    reductionToContractPatterns.populatePatterns(patterns);

    // Apply patterns for lowering contraction using outer product
    transform::ApplyLowerOuterProductPatternsOp lowerOuterProductPatterns;
    lowerOuterProductPatterns.populatePatterns(patterns);

    // Apply patterns for lowering masks
    transform::ApplyLowerMasksPatternsOp lowerMasksPatterns;
    lowerMasksPatterns.populatePatterns(patterns);

    // Apply patterns for rank-reducing subview
    transform::ApplyRankReducingSubviewPatternsOp rankReducingSubviewPatterns;
    rankReducingSubviewPatterns.populatePatterns(patterns);

    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

reductionToContractPatterns.populatePatterns(patterns);

// Apply patterns for lowering contraction using outer product
transform::ApplyLowerOuterProductPatternsOp lowerOuterProductPatterns;


Note that these patterns lower vector.outerproduct, rather than lowering vector.contract to vector.outerproduct: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td#L205


Contributor Author:
@nhat-nguyen turns out that the whole kernel needs to be bufferized before we can run this pass. Can we use structured-to-memref to do this? It would make things a lot easier.

@nhat-nguyen (Collaborator) Mar 11, 2024:
Yeah, you can use a combination of --triton-to-structured --canonicalize --triton-arith-to-linalg --structured-to-memref; this is the equivalent of --triton-to-linalg. However, this only converts the loads and stores to memref, so you would also need to run the bufferization pass to convert the remaining ops to use memref too.
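A minimal sketch of that combination, assuming one-shot bufferization is an acceptable way to handle the remaining tensor ops and that these upstream flags are registered with triton-shared-opt (if not, the bufferization step could be run separately through mlir-opt):

import subprocess

def ttir_to_memref(triton_shared_opt_path: str, src_path: str, dst_path: str) -> None:
    """Hypothetical sketch: the modular triton-shared passes (equivalent to
    --triton-to-linalg) followed by a bufferization step so the remaining
    tensor ops also become memrefs."""
    subprocess.check_call([
        triton_shared_opt_path,
        src_path,
        "--triton-to-structured",
        "--canonicalize",
        "--triton-arith-to-linalg",
        "--structured-to-memref",
        # Loads/stores are memrefs at this point; bufferize everything else too.
        "--one-shot-bufferize=allow-return-allocs-from-loops=true",
        "-o",
        dst_path,
    ])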

Contributor Author:
Yeah, I was a little worried about this; it looks like all the tensor ops will have to be lowered too.

Contributor Author:
@nhat-nguyen so I think that in order to lower to outer product, everything will have to be lowered to memref. I added a few off-the-shelf bufferization passes and am able to lower everything except the bufferization.to_tensor op. Any ideas how to fix this?

Contributor Author:
From what I'm hearing, bufferizing might be necessary before step (2):

         (1)    bufferize here     (2)                      (3)
 linalg.matmul -------------> vector.contract ----> vector.outerproduct -----> arm_sme.fmopa

In that case, optimizations may be difficult to do. I might still write a bufferization pass to make the compiler happy, but it's definitely not optimal.

@banach-space if I wanted to make future optimizations, could I modify the lowering to also accept tensors and write some lowering logic? I've been thinking it over, and bufferizing too early could be bad for performance; given that SME is a matmul engine, I'd imagine being able to accept tensors might be useful.


> from what I'm hearing, bufferizing might be necessary before step 2

I don't think so, that's still quite high level and there wouldn't be any SME ops yet :)

> Been thinking it over and bufferizing too early could be bad for performance and given that SME is a matmul engine I'd imagine being able to accept tensors might be useful

In my view, Tensor is too high-level/abstract a type for the ArmSME dialect. In particular, the hardware won't know what a tensor is. I suggest the following:

     (1)                    (2)                 (3)              bufferize here
 linalg.matmul -----> vector.contract ----> vector.outerproduct -----------------> arm_sme.fmopa

This would be consistent with what we currently do in MLIR: bufferization happens after lowering vector.contract to vector.outerproduct.

Btw, I deliberately made the distinction between vector.contract and vector.outerproduct, as that's representative of how the Vector dialect splits ops into "high level" (e.g. vector.contract) and "low level" (e.g. vector.outerproduct) vector ops. Here's an overview of the vectoriser in the context of SME:

Thanks again for working on this - hope this helps :)

Contributor Author:
The video makes sense, but I just want to clarify how the passes would be handled. This is what should happen:

  PassPipelineRegistration<> smeConversionPipeline(
      "sme-conversion",
      "Converts linalg.matmul to a more optimized form using SME",
      [](OpPassManager &pm) {
        pm.addPass(createMatmulTileConversionPass(true));   //tiling and vectorizing  linalg.matmul
        pm.addPass(createOuterProductVectorizationPass());  // lowering vector.contract to vector.outerproduct
        pm.addPass(createLinalgBufferizePass()); //bufferization happens here
      });

Let me know if this is accurate. And of course, happy to work on this :)


This makes sense to me, but I haven't checked the names of the actual passes. Also, you may need to configure some of these passes to do what you want.

I suggest implementing this step-by-step. At every step you can compare against the output that the Transform Dialect sequence generates - if there's divergence then that might mean trouble :) (but not necessarily)

Note that in your list of passes there's nothing SVE/SME specific, so you won't be lowering to SME just yet. That's fine - baby steps ;-)
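One way to get that reference output for comparison, assuming the upstream ArmSME matmul.mlir integration test (linked earlier in this thread) is used as the baseline; this only runs the transform schedule embedded in that file, not the later lowering flags from its RUN line:

import subprocess

def reference_sme_lowering(mlir_opt_path: str, matmul_test_path: str, out_path: str) -> None:
    """Hypothetical sketch: apply the Transform Dialect schedule embedded in the
    upstream matmul.mlir test to produce a baseline to diff against."""
    subprocess.check_call([
        mlir_opt_path,
        matmul_test_path,
        "--transform-interpreter",
        "-o",
        out_path,
    ])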

Contributor Author:
Just updated; it looks mostly bufferized.

@danikhan632 (Contributor Author):

The sme-opt seems to work now, producing valid output for arm_sme.outerproduct; however, I encounter an odd issue when lowering to LLVM. Perhaps "--convert-cf-to-llvm" fails?

If you look at the snippet, you can see that cf.br is still present. @banach-space could it be some complication with --convert-arm-sme-to-scf?

ll.mlir:75:5: error: Dialect `cf' not found for custom op 'cf.br' 
    cf.br ^bb1(%37 : index)
    ^
.../ll.mlir:75:5: note: Registered dialects: acc, amx, arm_neon, arm_sme, arm_sve, builtin, dlti, func, gpu, llvm, nvvm, omp, rocdl, spirv, x86vector ; for more info on dialect registration see https://mlir.llvm.org/getting_started/Faq...

I ran with and without it; either way, cf.br is still present in the output when it should have been lowered.


current output from sme-opt:

#map = affine_map<()[s0] -> (s0 * 16)>
#map1 = affine_map<(d0, d1) -> (-d0 + 32, d1)>
#map2 = affine_map<(d0, d1) -> (-d0 + 64, d1)>
#map3 = affine_map<()[s0, s1] -> (s0 * 16 + s1)>
#map4 = affine_map<()[s0, s1] -> (s0 * 64 + s1)>
#map5 = affine_map<(d0)[s0] -> (d0 * 16 + s0)>
#map6 = affine_map<(d0)[s0] -> (d0 + s0)>
module {
  func.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: memref<*xf16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) attributes {arm_sme.tiles_in_use = 34952 : i32} {
    %cst = arith.constant dense<0.000000e+00> : vector<[4]xf16>
    %c4 = arith.constant 4 : index
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst_1 = arith.constant 0.000000e+00 : f16
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    scf.for %arg15 = %c0 to %c32 step %c1 {
      scf.for %arg16 = %c0 to %c64 step %c1 {
        memref.store %cst_0, %alloc[%arg15, %arg16] : memref<32x64xf32>
      }
    }
    %0 = arith.addi %arg3, %c31_i32 : i32
    %1 = arith.divsi %0, %c32_i32 : i32
    %2 = arith.addi %arg4, %c63_i32 : i32
    %3 = arith.divsi %2, %c64_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %arg12, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.subi %1, %6 : i32
    %8 = arith.minsi %7, %c8_i32 : i32
    %9 = arith.remsi %arg12, %8 : i32
    %10 = arith.addi %6, %9 : i32
    %11 = arith.remsi %arg12, %4 : i32
    %12 = arith.divsi %11, %8 : i32
    %13 = arith.muli %10, %c32_i32 : i32
    %14 = arith.index_cast %13 : i32 to index
    %15 = arith.muli %12, %c64_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.index_cast %arg3 : i32 to index
    %18 = arith.index_cast %arg6 : i32 to index
    %19 = arith.muli %14, %18 : index
    %20 = arith.muli %17, %18 : index
    %21 = arith.index_cast %arg7 : i32 to index
    %22 = arith.index_cast %arg4 : i32 to index
    %23 = arith.addi %arg5, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = arith.muli %arg7, %c16_i32 : i32
    %26 = arith.index_cast %25 : i32 to index
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    memref.copy %alloc, %alloc_2 : memref<32x64xf32> to memref<32x64xf32>
    %27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_2, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index)  : i32 {
      %39 = arith.addi %arg18, %16 : index
      %40 = arith.remsi %39, %22 : index
      %41 = arith.subi %39, %40 : index
      %42 = arith.addi %40, %c64 : index
      %43 = arith.minsi %42, %22 : index
      %44 = arith.subi %43, %40 : index
      %reinterpret_cast_6 = memref.reinterpret_cast %arg1 to offset: [%39], sizes: [%c16, %44], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %45 = arith.subi %c64, %44 : index
      %reinterpret_cast_7 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %45], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %46 = arith.remsi %arg17, %18 : index
      %47 = arith.addi %20, %46 : index
      %48 = arith.subi %47, %arg17 : index
      %49 = arith.divsi %48, %18 : index
      %reinterpret_cast_8 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%49, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %50 = arith.subi %c32, %49 : index
      %reinterpret_cast_9 = memref.reinterpret_cast %arg0 to offset: [%46], sizes: [%50, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %51 = arith.muli %arg15, %c16_i32 : i32
      %52 = arith.subi %arg5, %51 : i32
      %53 = arith.index_cast %52 : i32 to index
      %54 = arith.minsi %53, %c16 : index
      %alloc_10 = memref.alloc() : memref<32x16xf16>
      %55 = arith.cmpi slt, %54, %c16 : index
      scf.if %55 {
        scf.for %arg19 = %c0 to %c32 step %c1 {
          scf.for %arg20 = %c0 to %c16 step %c1 {
            memref.store %cst_1, %alloc_10[%arg19, %arg20] : memref<32x16xf16>
          }
        }
      }
      %56 = arith.minsi %49, %c32 : index
      %57 = arith.subi %c32, %56 : index
      %base_buffer_11, %offset_12, %sizes_13:2, %strides_14:2 = memref.extract_strided_metadata %reinterpret_cast_8 : memref<?x16xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_15 = memref.reinterpret_cast %base_buffer_11 to offset: [%offset_12], sizes: [%56, %54], strides: [%strides_14#0, %strides_14#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %base_buffer_16, %offset_17, %sizes_18:2, %strides_19:2 = memref.extract_strided_metadata %reinterpret_cast_9 : memref<?x16xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_20 = memref.reinterpret_cast %base_buffer_16 to offset: [%offset_17], sizes: [%57, %54], strides: [%strides_19#0, %strides_19#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %reinterpret_cast_21 = memref.reinterpret_cast %alloc_10 to offset: [0], sizes: [%56, %54], strides: [16, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1]>>
      %58 = affine.apply #map()[%56]
      %reinterpret_cast_22 = memref.reinterpret_cast %alloc_10 to offset: [%58], sizes: [%57, %54], strides: [16, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      memref.copy %reinterpret_cast_15, %reinterpret_cast_21 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1]>>
      memref.copy %reinterpret_cast_20, %reinterpret_cast_22 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      %alloc_23 = memref.alloc() : memref<16x64xf16>
      %59 = arith.cmpi slt, %54, %c16 : index
      scf.if %59 {
        scf.for %arg19 = %c0 to %c16 step %c1 {
          scf.for %arg20 = %c0 to %c64 step %c1 {
            memref.store %cst_1, %alloc_23[%arg19, %arg20] : memref<16x64xf16>
          }
        }
      }
      %60 = arith.minsi %44, %c64 : index
      %61 = arith.subi %c64, %60 : index
      %base_buffer_24, %offset_25, %sizes_26:2, %strides_27:2 = memref.extract_strided_metadata %reinterpret_cast_6 : memref<16x?xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_28 = memref.reinterpret_cast %base_buffer_24 to offset: [%offset_25], sizes: [%54, %60], strides: [%strides_27#0, %strides_27#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %base_buffer_29, %offset_30, %sizes_31:2, %strides_32:2 = memref.extract_strided_metadata %reinterpret_cast_7 : memref<16x?xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_33 = memref.reinterpret_cast %base_buffer_29 to offset: [%offset_30], sizes: [%54, %61], strides: [%strides_32#0, %strides_32#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %reinterpret_cast_34 = memref.reinterpret_cast %alloc_23 to offset: [0], sizes: [%54, %60], strides: [64, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1]>>
      %reinterpret_cast_35 = memref.reinterpret_cast %alloc_23 to offset: [%60], sizes: [%54, %61], strides: [64, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      memref.copy %reinterpret_cast_28, %reinterpret_cast_34 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1]>>
      memref.copy %reinterpret_cast_33, %reinterpret_cast_35 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      %62 = vector.vscale
      %63 = arith.muli %62, %c4 : index
      %64 = arith.muli %62, %c4 : index
      %alloc_36 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      memref.copy %alloc, %alloc_36 : memref<32x64xf32> to memref<32x64xf32>
      scf.for %arg19 = %c0 to %c32 step %63 {
        scf.for %arg20 = %c0 to %c64 step %64 {
          scf.for %arg21 = %c0 to %c16 step %c1 {
            %67 = affine.min #map1(%arg19, %63)
            %68 = affine.min #map2(%arg20, %64)
            %69 = affine.min #map1(%arg19, %63)
            %70 = affine.min #map2(%arg20, %64)
            %71 = affine.apply #map3()[%arg19, %arg21]
            %72 = affine.apply #map4()[%arg21, %arg20]
            %73 = affine.apply #map4()[%arg19, %arg20]
            %reinterpret_cast_37 = memref.reinterpret_cast %alloc_36 to offset: [%73], sizes: [%69, %70], strides: [64, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            %74 = vector.create_mask %67 : vector<[4]xi1>
            %reinterpret_cast_38 = memref.reinterpret_cast %alloc_10 to offset: [%71], sizes: [%67], strides: [16] : memref<32x16xf16> to memref<?xf16, #map5>
            %75 = vector.vscale
            %76 = arith.muli %75, %c4 : index
            %77 = scf.for %arg22 = %c0 to %76 step %c1 iter_args(%arg23 = %cst) -> (vector<[4]xf16>) {
              %108 = vector.extractelement %74[%arg22 : index] : vector<[4]xi1>
              %109 = scf.if %108 -> (vector<[4]xf16>) {
                %110 = memref.load %reinterpret_cast_38[%arg22] : memref<?xf16, #map5>
                %111 = vector.insertelement %110, %arg23[%arg22 : index] : vector<[4]xf16>
                scf.yield %111 : vector<[4]xf16>
              } else {
                scf.yield %arg23 : vector<[4]xf16>
              }
              scf.yield %109 : vector<[4]xf16>
            }
            %78 = vector.shape_cast %77 : vector<[4]xf16> to vector<[4]x1xf16>
            %79 = vector.create_mask %68 : vector<[4]xi1>
            %reinterpret_cast_39 = memref.reinterpret_cast %alloc_23 to offset: [%72], sizes: [%68], strides: [1] : memref<16x64xf16> to memref<?xf16, #map6>
            %80 = vector.transfer_read %reinterpret_cast_39[%c0], %cst_1, %79 {in_bounds = [true]} : memref<?xf16, #map6>, vector<[4]xf16>
            %81 = vector.shape_cast %80 : vector<[4]xf16> to vector<1x[4]xf16>
            %82 = vector.create_mask %67, %68 : vector<[4]x[4]xi1>
            %83 = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
            %c4_40 = arith.constant 4 : index
            %84 = vector.vscale
            %85 = arith.muli %c4_40, %84 : index
            %86 = arith.index_cast %67 : index to i64
            %87 = arith.index_cast %85 : index to i64
            %88 = arith.minsi %86, %87 : i64
            %89 = arith.index_cast %88 : i64 to index
            %90 = vector.create_mask %68 : vector<[4]xi1>
            %c0_41 = arith.constant 0 : index
            %c1_42 = arith.constant 1 : index
            %91 = scf.for %arg22 = %c0_41 to %89 step %c1_42 iter_args(%arg23 = %83) -> (vector<[4]x[4]xf32>) {
              %108 = arith.addi %c0, %arg22 : index
              %109 = arm_sme.load_tile_slice %reinterpret_cast_37[%108, %c0], %90, %arg23, %arg22 {tile_id = 0 : i32} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]xi1>, vector<[4]x[4]xf32>
              scf.yield %109 : vector<[4]x[4]xf32>
            }
            %92 = arith.extf %78 : vector<[4]x1xf16> to vector<[4]x1xf32>
            %93 = arith.extf %81 : vector<1x[4]xf16> to vector<1x[4]xf32>
            %94 = vector.transpose %92, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
            %95 = vector.extract %94[0] : vector<[4]xf32> from vector<1x[4]xf32>
            %96 = vector.extract %93[0] : vector<[4]xf32> from vector<1x[4]xf32>
            %97 = vector.create_mask %67 : vector<[4]xi1>
            %98 = vector.create_mask %68 : vector<[4]xi1>
            %99 = arm_sme.outerproduct %95, %96 acc(%91) masks(%97, %98) {tile_id = 0 : i32} : vector<[4]xf32>, vector<[4]xf32>
            %c4_43 = arith.constant 4 : index
            %100 = vector.vscale
            %101 = arith.muli %c4_43, %100 : index
            %102 = arith.index_cast %67 : index to i64
            %103 = arith.index_cast %101 : index to i64
            %104 = arith.minsi %102, %103 : i64
            %105 = arith.index_cast %104 : i64 to index
            %106 = vector.create_mask %68 : vector<[4]xi1>
            %c0_44 = arith.constant 0 : index
            %c1_45 = arith.constant 1 : index
            scf.for %arg22 = %c0_44 to %105 step %c1_45 {
              %108 = arith.addi %c0, %arg22 : index
              arm_sme.store_tile_slice %99, %arg22, %106, %reinterpret_cast_37[%108, %c0] {tile_id = 0 : i32} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]xi1>, vector<[4]x[4]xf32>
            }
            %107 = affine.apply #map4()[%arg19, %arg20]
            %reinterpret_cast_46 = memref.reinterpret_cast %alloc_36 to offset: [%107], sizes: [%69, %70], strides: [64, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            memref.copy %reinterpret_cast_37, %reinterpret_cast_46 : memref<?x?xf32, strided<[64, 1], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
          }
        }
      }
      scf.for %arg19 = %c0 to %c32 step %c1 {
        scf.for %arg20 = %c0 to %c64 step %c1 {
          %67 = memref.load %alloc_36[%arg19, %arg20] : memref<32x64xf32>
          %68 = memref.load %arg16[%arg19, %arg20] : memref<32x64xf32>
          %69 = arith.addf %67, %68 : f32
          memref.store %69, %alloc_36[%arg19, %arg20] : memref<32x64xf32>
        }
      }
      %65 = arith.addi %arg17, %c16 : index
      %66 = arith.addi %arg18, %26 : index
      scf.yield %alloc_36, %65, %66 : memref<32x64xf32>, index, index
    }
    %28 = arith.index_cast %arg8 : i32 to index
    %29 = arith.muli %14, %28 : index
    %30 = arith.addi %29, %16 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf16> to memref<32x64xf16, strided<[?, 1], offset: ?>>
    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf16>
    scf.for %arg15 = %c0 to %c32 step %c1 {
      scf.for %arg16 = %c0 to %c64 step %c1 {
        %39 = memref.load %27#0[%arg15, %arg16] : memref<32x64xf32>
        %40 = arith.truncf %39 : f32 to f16
        memref.store %40, %alloc_3[%arg15, %arg16] : memref<32x64xf16>
      }
    }
    %31 = arith.addi %14, %c32 : index
    %32 = arith.minsi %31, %17 : index
    %33 = arith.subi %32, %14 : index
    %34 = arith.addi %16, %c64 : index
    %35 = arith.minsi %34, %22 : index
    %36 = arith.subi %35, %16 : index
    %37 = arith.minsi %33, %c32 : index
    %38 = arith.minsi %36, %c64 : index
    %reinterpret_cast_4 = memref.reinterpret_cast %alloc_3 to offset: [0], sizes: [%37, %38], strides: [64, 1] : memref<32x64xf16> to memref<?x?xf16, strided<[64, 1]>>
    %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %reinterpret_cast : memref<32x64xf16, strided<[?, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
    %reinterpret_cast_5 = memref.reinterpret_cast %base_buffer to offset: [%offset], sizes: [%37, %38], strides: [%strides#0, 1] : memref<f16> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    memref.copy %reinterpret_cast_4, %reinterpret_cast_5 : memref<?x?xf16, strided<[64, 1]>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    return
  }
}

The function that compiles the kernel above:

def _ttsharedir_to_llir(ttsharedir: str):
    with tempfile.TemporaryDirectory() as tmpdir:
        ttshared_path = os.path.join(tmpdir, "ttshared.mlir")
        llmlir_path = os.path.join(tmpdir, "ll.mlir")
        llir_path = os.path.join(tmpdir, "ll.ir")
        Path(ttshared_path).write_text(ttsharedir)
        mlir_opt_path = _get_llvm_bin_path("mlir-opt")
        # TritonShared-MLIR to LLVM-MLIR
        subprocess.check_call([
            mlir_opt_path,
            ttshared_path,
            "--one-shot-bufferize=allow-return-allocs-from-loops=true",
            "--convert-arm-sme-to-llvm", 
            "--convert-vector-to-llvm=enable-arm-sve",
            "--convert-arith-to-llvm",
            "--convert-math-to-llvm",
            "--convert-complex-to-llvm",
            "--convert-func-to-llvm",
            "--convert-index-to-llvm",
            "--finalize-memref-to-llvm",
            "--convert-scf-to-cf",
            "--convert-cf-to-llvm", 
            "-o", llmlir_path
        ])
        # LLVM-MLIR to LLVM-IR
        mlir_translate_path = _get_llvm_bin_path("mlir-translate")
        subprocess.check_call([mlir_translate_path, llmlir_path, "--mlir-to-llvmir", "-o", llir_path])
        return Path(llir_path).read_text()

Output kernel before mlir-translate:

sme_matmul_lowered.mlir.txt

@danikhan632 danikhan632 requested a review from banach-space April 2, 2024 14:22
@MacDue

MacDue commented Apr 2, 2024

The sme-opt pass seems to work now, producing valid output for sme.outerproduct; however, I encounter an odd issue when lowering to LLVM. Perhaps "--convert-cf-to-llvm" fails?

You appear to be doing an f16 -> f32 matmul, which (for this example) requires slightly different tiling and passes. I think you need tile sizes of [4 x vscale, 4 x vscale, 2] (i.e. the reduction dimension is unrolled by two). Then you need to apply the arm-sme-vector-legalization pass fairly early (before convert-vector/arith-to-arm-sme), and the arm-sme-outer-product-fusion pass just after convert-vector-to-arm-sme. This should result in arm_sme.fmopa_2way operations (rather than arm_sme.outerproduct ops).
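
For concreteness, here is a minimal sketch of that ordering, written in the same mlir-opt-driver style used elsewhere in this thread. The mlir_opt_path, input_path, and output_path names are placeholders, and the exact flag set is only a suggestion, not a verified pipeline:

import subprocess

# Placeholders; substitute your own mlir-opt binary and files.
mlir_opt_path = "mlir-opt"
input_path = "kernel_linalg.mlir"
output_path = "kernel_sme.mlir"

subprocess.check_call([
    mlir_opt_path, input_path,
    # Legalize scalable vector types early, before any *-to-arm-sme conversion.
    "--arm-sme-vector-legalization",
    "--convert-vector-to-arm-sme",
    # Fuse pairs of masked outer products into widening arm_sme.fmopa_2way ops.
    "--arm-sme-outer-product-fusion",
    "--convert-arith-to-arm-sme",
    "--allocate-arm-sme-tiles",
    "--convert-arm-sme-to-scf",
    "-o", output_path,
])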

@danikhan632
Contributor Author

danikhan632 commented Apr 2, 2024

The sme-opt pass seems to work now, producing valid output for sme.outerproduct; however, I encounter an odd issue when lowering to LLVM. Perhaps "--convert-cf-to-llvm" fails?

You appear to be doing an f16 -> f32 matmul, which (for this example) requires slightly different tiling and passes. I think you need tile sizes of [4 x vscale, 4 x vscale, 2] (i.e. the reduction dimension is unrolled by two). Then you need to apply the arm-sme-vector-legalization pass fairly early (before convert-vector/arith-to-arm-sme), and the arm-sme-outer-product-fusion pass just after convert-vector-to-arm-sme. This should result in arm_sme.fmopa_2way operations (rather than arm_sme.outerproduct ops).

Yeah, I might have forgotten that we are going from f16 -> f32.

Understood. I guess the tiling logic has to be a bit different since this kernel uses an f32 accumulator.
By the way, here is the original kernel for reference, before any SME/LLVM lowerings are applied.

edit:

I was a bit confused about the outer-product-fusion pass; it turns out these passes are pretty new and are not in LLVM commit 4017f04e, which the current Triton branch uses. @nhat-nguyen can this be bumped to Triton hash ea9777d?

@danikhan632
Contributor Author

danikhan632 commented Apr 3, 2024

The sme-opt pass seems to work now, producing valid output for sme.outerproduct; however, I encounter an odd issue when lowering to LLVM. Perhaps "--convert-cf-to-llvm" fails?

You appear to be doing an f16 -> f32 matmul, which (for this example) requires slightly different tiling and passes. I think you need tile sizes of [4 x vscale, 4 x vscale, 2] (i.e. the reduction dimension is unrolled by two). Then you need to apply the arm-sme-vector-legalization pass fairly early (before convert-vector/arith-to-arm-sme), and the arm-sme-outer-product-fusion pass just after convert-vector-to-arm-sme. This should result in arm_sme.fmopa_2way operations (rather than arm_sme.outerproduct ops).

Got this working now too. I have some concerns about the point where I have to change the dims from 1 -> 2 for widening, but that can be worried about later.

fmopa_2way is being produced right now:
kernel.mlir.txt

I have a minor issue with this:

ll.mlir:8:10: error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
    %3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i64, ptr)> to memref<*xf16>
         ^

Do you have any idea how to remedy this, and is it caused by something SME-related or is it just a triton-shared/MLIR thing?

Here is my first MLIR conversion before the lowerings to LLVM; I think the pass order is right:

        subprocess.check_call([mlir_opt_path, sme_first_pass,
            "--canonicalize",
            "--eliminate-empty-tensors",
            "--convert-linalg-to-loops",
            "--empty-tensor-to-alloc-tensor",
            "--expand-strided-metadata",
            "--arm-sme-vector-legalization",
            "--convert-vector-to-arm-sme",
            "--arm-sme-outer-product-fusion",
            "--arm-sve-legalize-vector-storage",
            "--convert-arith-to-arm-sme",
            "--allocate-arm-sme-tiles",
            "--convert-arm-sme-to-scf",
            "--convert-vector-to-scf",
            "-o",
            mlir_sme_pass])

@nhat-nguyen
Collaborator

@danikhan632 unrealized_conversion_cast ops are inserted automatically by TypeConverters during the dialect conversion when resulting types are incompatible. This is unrelated to triton-shared. One way to debug this is to first find out at which pass these unrealized_conversion_cast ops start appearing.
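
One sketch of how to bisect that (the binary path, input file, and flag list below are placeholders; reuse whatever your failing pipeline actually passes to mlir-opt): --mlir-print-ir-after-all dumps the IR after every pass, so you can look for the first dump that still contains the cast.

import subprocess

# Placeholders standing in for the real pipeline inputs.
mlir_opt_path = "mlir-opt"
ttshared_path = "ttshared.mlir"
lowering_flags = [
    "--convert-arm-sme-to-llvm",
    "--convert-vector-to-llvm=enable-arm-sve",
    "--finalize-memref-to-llvm",
    "--convert-func-to-llvm",
]

# No check=True: the pipeline is expected to fail, we only want the dumps,
# which mlir-opt writes to stderr.
result = subprocess.run(
    [mlir_opt_path, ttshared_path, "--mlir-print-ir-after-all", *lowering_flags],
    capture_output=True, text=True,
)

# Report the first per-pass dump that still contains an unrealized_conversion_cast.
for dump in result.stderr.split("IR Dump After")[1:]:
    if "unrealized_conversion_cast" in dump:
        print("Cast still present after:", dump.splitlines()[0].strip())
        break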

@danikhan632
Contributor Author

danikhan632 commented Apr 3, 2024

@danikhan632 unrealized_conversion_cast ops are inserted automatically by TypeConverters during the dialect conversion when resulting types are incompatible. This is unrelated to triton-shared. One way to debug this is to first find out at which pass these unrealized_conversion_cast ops start appearing.

I figured as much; I think it's got something to do with the way the inputs are passed:

  llvm.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: i64, %arg1: !llvm.ptr, %arg2: i64, %arg3: !llvm.ptr, %arg4: i64, %arg5: !llvm.ptr, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32) attributes {arm_sme.tiles_in_use = 34952 : i32} {
    %0 = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
    %1 = llvm.insertvalue %arg4, %0[0] : !llvm.struct<(i64, ptr)> 
    %2 = llvm.insertvalue %arg5, %1[1] : !llvm.struct<(i64, ptr)> 
    %3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i64, ptr)> to memref<*xf16>

I think the kernel is expecting an i64 value and a pointer for each input, but it gets a memref.

@zhaoshiz

zhaoshiz commented Apr 4, 2024

  llvm.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: i64, %arg1: !llvm.ptr, %arg2: i64, %arg3: !llvm.ptr, %arg4: i64, %arg5: !llvm.ptr, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32) attributes {arm_sme.tiles_in_use = 34952 : i32} {
    %0 = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
    %1 = llvm.insertvalue %arg4, %0[0] : !llvm.struct<(i64, ptr)> 
    %2 = llvm.insertvalue %arg5, %1[1] : !llvm.struct<(i64, ptr)> 
    %3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i64, ptr)> to memref<*xf16>

I see unrealized_conversion_cast errors when lowering to the LLVM dialect. In my case, it is caused by the user of this cast (%3 above) not being lowered to the LLVM dialect. I would check the dialect/op of the user and try to find the pass(es) to lower it.

@danikhan632
Contributor Author

@zhaoshiz
I think the issue is with the memref allocation of scalable vectors here; any ideas on how to fix it?

...
%62 = arith.muli %vscale, %c4 : index
%63 = arith.muli %vscale, %c4 : index
%alloc_37 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_37 : memref<32x64xf32> to memref<32x64xf32>
scf.for %arg19 = %c0 to %c32 step %62 {
  scf.for %arg20 = %c0 to %c64 step %63 {
    scf.for %arg21 = %c0 to %c16 step %c2 {
      %alloca = memref.alloca() : memref<vector<2x[4]xf16>>
      %alloca_38 = memref.alloca() : memref<vector<2x[4]xi1>>

...


AFTER MLIR PASSES

  %33 = "arith.constant"() <{value = 2 : index}> : () -> index
  %34 = "builtin.unrealized_conversion_cast"(%33) : (index) -> i64
  %35 = "arith.constant"() <{value = dense<0.000000e+00> : vector<[4]xf16>}> : () -> vector<[4]xf16>
  %36 = "arith.constant"() <{value = -1 : index}> : () -> index
  %37 = "arith.constant"() <{value = dense<false> : vector<2x[4]xi1>}> : () -> vector<2x[4]xi1>
  %38 = "builtin.unrealized_conversion_cast"(%37) : (vector<2x[4]xi1>) -> !llvm.array<2 x vector<[4]xi1>>
  %39 = "builtin.unrealized_conversion_cast"(%21) : (index) -> i64
  %40 = "builtin.unrealized_conversion_cast"(%21) : (index) -> i64
  %41 = "llvm.mlir.constant"() <{value = 32 : index}> : () -> i64

...
AFTER MORE MLIR PASSES
  %33 = "arith.constant"() <{value = 64 : i32}> : () -> i32
  %34 = "arith.constant"() <{value = 32 : i32}> : () -> i32
  %35 = "arith.constant"() <{value = 8 : i32}> : () -> i32
  %36 = "arith.constant"() <{value = 4 : index}> : () -> index
  %37 = "arith.constant"() <{value = 2 : index}> : () -> index
  %38 = "builtin.unrealized_conversion_cast"(%37) : (index) -> i64
  %39 = "arith.constant"() <{value = dense<0.000000e+00> : vector<[4]xf16>}> : () -> vector<[4]xf16>
  %40 = "arith.constant"() <{value = -1 : index}> : () -> index
  %41 = "arith.constant"() <{value = dense<false> : vector<2x[4]xi1>}> : () -> vector<2x[4]xi1>
  %42 = "builtin.unrealized_conversion_cast"(%41) : (vector<2x[4]xi1>) -> !llvm.array<2 x vector<[4]xi1>>
  %43 = "builtin.unrealized_conversion_cast"(%24) : (index) -> i64


  <unknown>:0: error: failed to legalize operation 'builtin.unrealized_conversion_cast' that was explicitly marked illegal
<unknown>:0: note: see current operation: %38 = "builtin.unrealized_conversion_cast"(%37) : (i64) -> index

@MacDue

MacDue commented Apr 5, 2024

I think it's more helpful to look at the users of an unrealized_conversion_cast rather than the cast (especially when posting snippets like the above). The users will be the thing that's keeping the casts around (likely because they've not been lowered correctly).

It looks to me like the arith dialect has not been lowered. Also, stuff like making allocas for predicates (i.e. memref<vector<2x[4]xi1>>) is something you generally want to avoid; if you do have to keep them around, they need to be legalised by -arm-sve-legalize-vector-storage (but I don't think that's the cause of the issue here).

@danikhan632
Contributor Author

I think it's more helpful to look at the users of an unrealized_conversion_cast rather than the cast (especially when posting snippets like the above). The users will be the thing that's keeping the casts around (likely because they've not been lowered correctly).

It looks to me like the arith dialect has not been lowered. Also, stuff like making allocas for predicates (i.e. memref<vector<2x[4]xi1>>) is something you generally want to avoid; if you do have to keep them around, they need to be legalised by -arm-sve-legalize-vector-storage (but I don't think that's the cause of the issue here).

Yeah, I don't think '--convert-arith-to-arm-sme' is really doing anything here. By the way, I wanted to verify that the kernel I generated is legitimate, and that the only thing I should have to do is run it through mlir-opt and then through mlir-translate.

I also figured that memref<vector<2x[4]xi1>> is not great, since these vector sizes aren't known until run time.

matmul.mlir.txt

@MacDue

MacDue commented Apr 5, 2024

I think the allocas for the scalable vectors come from using default lowering of --convert-vector-to-scf. If you do --convert-vector-to-scf=full-unroll, those should be avoided.
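
For reference, a minimal sketch of the pass list posted earlier with just that one option changed; the mlir_opt_path, sme_first_pass, and mlir_sme_pass names are placeholders for the same variables used in that snippet:

import subprocess

# Placeholders standing in for the earlier pipeline's variables.
mlir_opt_path = "mlir-opt"
sme_first_pass = "sme_first_pass.mlir"
mlir_sme_pass = "mlir_sme_pass.mlir"

subprocess.check_call([mlir_opt_path, sme_first_pass,
    "--canonicalize",
    "--eliminate-empty-tensors",
    "--convert-linalg-to-loops",
    "--empty-tensor-to-alloc-tensor",
    "--expand-strided-metadata",
    "--arm-sme-vector-legalization",
    "--convert-vector-to-arm-sme",
    "--arm-sme-outer-product-fusion",
    "--arm-sve-legalize-vector-storage",
    "--convert-arith-to-arm-sme",
    "--allocate-arm-sme-tiles",
    "--convert-arm-sme-to-scf",
    # full-unroll avoids the memref<vector<...>> allocas produced by the default lowering
    "--convert-vector-to-scf=full-unroll",
    "-o",
    mlir_sme_pass])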

@danikhan632
Contributor Author

I think the allocas for the scalable vectors come from using default lowering of --convert-vector-to-scf. If you do --convert-vector-to-scf=full-unroll, those should be avoided.

It lowers to LLVM IR successfully now; the issue is when compiling with llc. Any ideas why?
sme_matmul_lowered.llir.txt

LLVM ERROR: Cannot select: t155: i64 = vscale Constant:i64<1024>
  t154: i64 = Constant<1024>
In function: matmul_kernel_0d1d2d34567c89c1011c
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: /home/green/.triton/llvm/llvm-6f44bb77-ubuntu-x64/bin/llc /tmp/tmp7lnmobsz/kernel.ll -o /tmp/tmp7lnmobsz/kernel.o
1.      Running pass 'Function Pass Manager' on module '/tmp/tmp7lnmobsz/kernel.ll'.
2.      Running pass 'X86 DAG->DAG Instruction Selection' on function '@matmul_kernel_0d1d2d34567c89c1011c'

@MacDue

MacDue commented Apr 6, 2024

What flags are you using? To compile with llc (for example) you'd need to pass -mattr=+sve,+sme when using SVE and SME.
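
Note that the earlier llc crash mentions 'X86 DAG->DAG Instruction Selection', so the host backend is being used. A minimal sketch of the invocation, reusing the llc_path/src_path/dst_path names from _llir_to_bin; the -mtriple flag is my assumption for forcing the AArch64 backend:

import subprocess

# Placeholders standing in for the values used in _llir_to_bin.
llc_path = "llc"
src_path = "kernel.ll"
dst_path = "kernel.s"   # default llc output is text assembly

subprocess.check_call([
    llc_path, src_path,
    "-mtriple=aarch64-linux-gnu",   # assumption: select the AArch64 backend explicitly
    "-mattr=+sve,+sme",
    "-o", dst_path,
])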

@danikhan632
Contributor Author

What flags are you using? To compile with llc (for example) you'd need to pass -mattr=+sve,+sme when using SVE and SME.

Ah, I see. I think this is where the SME userspace emulator is needed.

I think this is correct; going to switch over to my Arm system to test it.

def _llir_to_bin(llir: str, metadata):
    pattern = r"define void @(\w+)\(.+"
    matches = re.findall(pattern, llir)
    assert len(matches) == 1
    metadata["name"] = matches[0]
    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, "kernel.ll")
        dst_path = os.path.join(tmpdir, "kernel.o")
        Path(src_path).write_text(llir)
        llc_path = _get_llvm_bin_path("llc")
        subprocess.check_call(["/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])
        # Actually it's text-format assembly.  Use read_text().
        return Path(dst_path).read_text()

@danikhan632
Contributor Author

What flags are you using? To compile with llc (for example) you'd need to pass -mattr=+sve,+sme when using SVE and SME.

I did that and got this error:

        subprocess.check_call([ "/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])

output:

/tmp/tmp1j2jbysw/kernel.s: Assembler messages:
/tmp/tmp1j2jbysw/kernel.s:126: Error: selected processor does not support `rdvl x8,#1'
/tmp/tmp1j2jbysw/kernel.s:130: Error: selected processor does not support `cntw x24'
/tmp/tmp1j2jbysw/kernel.s:409: Error: selected processor does not support `ptrue p2.s'
/tmp/tmp1j2jbysw/kernel.s:410: Error: selected processor does not support `index z6.s,#0,#1'
/tmp/tmp1j2jbysw/kernel.s:417: Error: selected processor does not support `incw x22'
/tmp/tmp1j2jbysw/kernel.s:420: Error: selected processor does not support `addvl x8,x8,#8'
/tmp/tmp1j2jbysw/kernel.s:438: Error: selected processor does not support `incw x25'
/tmp/tmp1j2jbysw/kernel.s:439: Error: selected processor does not support `addvl x20,x20,#1'
/tmp/tmp1j2jbysw/kernel.s:486: Error: selected processor does not support `index z6.s,#0,#1'
/tmp/tmp1j2jbysw/kernel.s:490: Error: selected processor does not support `ptrue p2.s'
/tmp/tmp1j2jbysw/kernel.s:513: Error: selected processor does not support `mov z0.s,w9'
/tmp/tmp1j2jbysw/kernel.s:514: Error: selected processor does not support `cmpgt p0.s,p2/z,z0.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:516: Error: selected processor does not support `ld1h {z0.s},p0/z,[x13,x10,lsl#1]'
/tmp/tmp1j2jbysw/kernel.s:519: Error: selected processor does not support `ld1h {z1.s},p0/z,[x11,x10,lsl#1]'
/tmp/tmp1j2jbysw/kernel.s:522: Error: unknown mnemonic `zero' -- `zero {za0.s}'
/tmp/tmp1j2jbysw/kernel.s:530: Error: operand 1 must be a list of SVE vector registers -- `ld1w {za0h.s[w12,0]},p0/z,[x13]'
/tmp/tmp1j2jbysw/kernel.s:536: Error: selected processor does not support `mov z2.s,w8'
/tmp/tmp1j2jbysw/kernel.s:539: Error: selected processor does not support `cmpgt p0.s,p2/z,z2.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:540: Error: selected processor does not support `mov z2.h,#0'
/tmp/tmp1j2jbysw/kernel.s:541: Error: selected processor does not support `mov z3.s,p0/z,#1'
/tmp/tmp1j2jbysw/kernel.s:554: Error: selected processor does not support `whilels p0.s,xzr,x11'
/tmp/tmp1j2jbysw/kernel.s:555: Error: selected processor does not support `lastb w13,p0,z3.s'
/tmp/tmp1j2jbysw/kernel.s:558: Error: selected processor does not support `mov z4.s,w11'
/tmp/tmp1j2jbysw/kernel.s:559: Error: selected processor does not support `cmpeq p0.s,p2/z,z6.s,z4.s'
/tmp/tmp1j2jbysw/kernel.s:561: Error: selected processor does not support `mov z2.h,p0/m,h4'
/tmp/tmp1j2jbysw/kernel.s:564: Error: selected processor does not support `mov z4.h,#0'
/tmp/tmp1j2jbysw/kernel.s:579: Error: selected processor does not support `whilels p0.s,xzr,x11'
/tmp/tmp1j2jbysw/kernel.s:580: Error: selected processor does not support `lastb w13,p0,z3.s'
/tmp/tmp1j2jbysw/kernel.s:583: Error: selected processor does not support `mov z5.s,w11'
/tmp/tmp1j2jbysw/kernel.s:584: Error: selected processor does not support `cmpeq p0.s,p2/z,z6.s,z5.s'
/tmp/tmp1j2jbysw/kernel.s:586: Error: selected processor does not support `mov z4.h,p0/m,h5'
/tmp/tmp1j2jbysw/kernel.s:589: Error: selected processor does not support `mov z3.s,w8'
/tmp/tmp1j2jbysw/kernel.s:590: Error: selected processor does not support `mov z5.s,w9'
/tmp/tmp1j2jbysw/kernel.s:593: Error: selected processor does not support `cmpgt p1.s,p2/z,z3.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:594: Error: selected processor does not support `cmpgt p0.s,p2/z,z5.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:595: Error: selected processor does not support `zip2 z3.s,z2.s,z4.s'
/tmp/tmp1j2jbysw/kernel.s:596: Error: selected processor does not support `zip1 z2.s,z2.s,z4.s'
/tmp/tmp1j2jbysw/kernel.s:597: Error: selected processor does not support `zip2 z4.s,z0.s,z1.s'
/tmp/tmp1j2jbysw/kernel.s:598: Error: selected processor does not support `zip1 z0.s,z0.s,z1.s'
/tmp/tmp1j2jbysw/kernel.s:600: Error: selected processor does not support `zip2 p2.s,p1.s,p1.s'
/tmp/tmp1j2jbysw/kernel.s:602: Error: selected processor does not support `zip1 p1.s,p1.s,p1.s'
/tmp/tmp1j2jbysw/kernel.s:603: Error: selected processor does not support `zip2 p3.s,p0.s,p0.s'
/tmp/tmp1j2jbysw/kernel.s:604: Error: selected processor does not support `uzp1 z1.h,z2.h,z3.h'
/tmp/tmp1j2jbysw/kernel.s:605: Error: selected processor does not support `uzp1 z0.h,z0.h,z4.h'
/tmp/tmp1j2jbysw/kernel.s:606: Error: selected processor does not support `zip1 p4.s,p0.s,p0.s'
/tmp/tmp1j2jbysw/kernel.s:607: Error: selected processor does not support `uzp1 p1.h,p1.h,p2.h'
/tmp/tmp1j2jbysw/kernel.s:608: Error: selected processor does not support `uzp1 p2.h,p4.h,p3.h'
/tmp/tmp1j2jbysw/kernel.s:609: Error: unknown mnemonic `fmopa' -- `fmopa za0.s,p1/m,p2/m,z1.h,z0.h'
/tmp/tmp1j2jbysw/kernel.s:617: Error: operand 1 must be a list of SVE vector registers -- `st1w {za0h.s[w12,0]},p0,[x13]'

I took a look at sme_matmul.mlir and am trying to figure out where these shared libs are:

// %mcr_aarch64_cmd \
//-e=main -entry-point-result=void \
//-march=aarch64 -mattr="+sve,+sme" \
//-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
// RUN: FileCheck %s

@banach-space

Generic Advice

It's best to run the tests as part of the build process of MLIR (or afterwards) and then to copy the build commands from tests. CMake flags to run the SME integration tests are documented here:

  -DMLIR_INCLUDE_INTEGRATION_TESTS=On
  -DMLIR_RUN_ARM_SME_TESTS=On
  -DARM_EMULATOR_EXECUTABLE=<path-to-emulator> 

Then, during/after the build, you can either run all the tests:

ninja check-mlir

or just selected integration tests:

cd <llvm-build-dir>
# Please adjust paths to match your system
bin/llvm-lit -va ../../mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir

Note that I am using -va - this will make LIT print the RUN commands. You can extract what's needed from those RUN lines. I would use these as your reference commands.

I would make sure that these tests work for you before trying to run things manually.

Specific advice

As you have noted, the tests will contain something like this:

// %mcr_aarch64_cmd \
//-e=main -entry-point-result=void \
//-march=aarch64 -mattr="+sve,+sme" \
//-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
// RUN: FileCheck %s

%mcr_aarch64_cmd is a convenience wrapper for mlir-cpu-runner:

This is important - it means that ^^^ defines flags to be passed to mlir-cpu-runner. However, you are passing these flags to qemu-aarch64-static:

        subprocess.check_call([ "/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])

That's incorrect and won't work :)

Now, it also looks like you are passing llc to qemu-aarch64-static (guessing based on llc_path above). That's not required (*) - llc is a driver for the LLVM backend that lowers LLVM IR to Machine Code.

Also, we don't really use llc for the integration tests. Instead, we rely on mlir-cpu-runner to drive that part of the compilation (the name is a bit confusing).

As for -march=aarch64 -mattr="+sve,+sme", those flags are passed to mlir-cpu-runner (i.e. %mcr_aarch64_cmd) - that's to inform the compilation pipeline (driven by mlir-cpu-runner) to target SVE.
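
For illustration, a rough sketch of the kind of command %mcr_aarch64_cmd stands for; every path below is a placeholder for your own LLVM build tree, and the shared-library file names (especially the SME ABI one) are assumptions that you should replace with whatever your expanded RUN lines show:

import os
import subprocess

llvm_build_dir = "/path/to/llvm-build"   # placeholder
lowered_mlir = "sme_matmul.mlir"         # placeholder input file
runner = os.path.join(llvm_build_dir, "bin", "mlir-cpu-runner")

# %mlir_runner_utils / %mlir_c_runner_utils / %arm_sme_abi_shlib expand to
# libraries under <llvm-build-dir>/lib; the last name below is an assumption.
shared_libs = ",".join([
    os.path.join(llvm_build_dir, "lib", "libmlir_runner_utils.so"),
    os.path.join(llvm_build_dir, "lib", "libmlir_c_runner_utils.so"),
    os.path.join(llvm_build_dir, "lib", "libmlir_arm_sme_abi_stubs.so"),
])

subprocess.check_call([
    runner, lowered_mlir,
    "-e=main", "-entry-point-result=void",
    "-march=aarch64", "-mattr=+sve,+sme",
    "-shared-libs=" + shared_libs,
])

On an x86 host, the whole command above would itself be run under the emulator configured via ARM_EMULATOR_EXECUTABLE.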

Suggestion

  1. Try running SME integration tests in MLIR. This will give you a working reference.
  2. Share your build step and try to run your binary from command line rather than via Python. In particular, what is it that you are trying to run? An MLIR file? An LLVM IR file? A binary? What do you get at the end of your compilation?

HTH :)
-Andrzej

(*) Unless you've cross-compiled it, but I highly doubt it.

@danikhan632
Contributor Author

Generic Advice

It's best to run the tests as part of the build process of MLIR (or afterwards) and then to copy the build commands from tests. CMake flags to run the SME integration tests are documented here:

  -DMLIR_INCLUDE_INTEGRATION_TESTS=On
  -DMLIR_RUN_ARM_SME_TESTS=On
  -DARM_EMULATOR_EXECUTABLE=<path-to-emulator> 

Then, during/after the build, you can either run all the tests:

ninja check-mlir

or just selected integration tests:

cd <llvm-build-dir>
# Please adjust paths to match your system
bin/llvm-lit -va ../../mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir

Note that I am using -va - this will make LIT print the RUN commands. You can extract what's needed from those RUN lines. I would use these as your reference commands.

I would make sure that these tests work for you before trying to run things manually.

Specific advice

As you have noted, the tests will contain something like this:

// %mcr_aarch64_cmd \
//-e=main -entry-point-result=void \
//-march=aarch64 -mattr="+sve,+sme" \
//-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
// RUN: FileCheck %s

%mcr_aarch64_cmd is a convenience wrapper for mlir-cpu-runner:

This is important - it means that ^^^ defines flags to be passed to mlir-cpu-runner. However, you are passing these flags to qemu-aarch64-static:

        subprocess.check_call([ "/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])

That's incorrect and won't work :)

Now, it also looks like you are passing llc to qemu-aarch64-static (guessing based on llc_path above). That's not required (*) - llc is a driver for the LLVM backend that lowers LLVM IR to Machine Code.

Also, we don't really use llc for the integration tests. Instead, we rely on mlir-cpu-runner to drive that part of the compilation (the name is a bit confusing).

As for -march=aarch64 -mattr="+sve,+sme", those flags are passed to mlir-cpu-runner (i.e. %mcr_aarch64_cmd) - that's to inform the compilation pipeline (driven by mlir-cpu-runner) to target SVE.

Suggestion

  1. Try running SME integration tests in MLIR. This will give you a working reference.
  2. Share your build step and try to run your binary from command line rather than via Python. In particular, what is it that you are trying to run? An MLIR file? An LLVM IR file? A binary? What do you get at the end of your compilation?

HTH :) -Andrzej

(*) Unless you've cross-compiled it, but I highly doubt it.

Got it. I'm trying to pass the LLVM IR through llc to compile it to a binary. I think mlir-cpu-runner will be good for IR tests, but I'm looking to run this E2E.

Not sure if the MLIR CPU runner can do this:

  -DARM_EMULATOR_EXECUTABLE=<path-to-emulator> 
  

Also, is this the instruction emulator?

https://developer.arm.com/Tools%20and%20Software/Arm%20Instruction%20Emulator

@banach-space

Also, is this the instruction emulator?

ArmIE is one emulator, but based on the website it only supports SVE and SVE2 (so no SME):

Arm Instruction Emulator (ArmIE) emulates Scalable Vector Extension (SVE) and SVE2 instructions on AArch64 platforms.

QEMU does support SME: https://qemu-project.gitlab.io/qemu/system/arm/cpu-features.html

Btw, I forgot to answer your other question:

I took a look at the sme_matmul.mlir and trying to figure out where these shared libs are

These are MLIR runtime libs - you will find them in the LLVM build directory under the lib directory. Also, note that:

  • on Linux, %mlir_runner_utils expands to libmlir_runner_utils.so.

:)

@steplong

Hi, @danikhan632 I'm trying this PR out and I'm seeing a correctness issue where the result of the matmul doesn't match the expected output. I'm looking into it right now, but I was wondering if you were able to run the compiled test_matmul with QEMU.

@danikhan632
Contributor Author

Hi, @danikhan632 I'm trying this PR out and I'm seeing a correctness issue where the result of the matmul doesn't match the expected output. I'm looking into it right now, but I was wondering if you were able to run the compiled test_matmul with QEMU.

Yeah, I've been able to get it to work with some caveats:
only on Ubuntu 22.04, not 20.04;
QEMU must be built from source.
There are some more changes I need to push, but yes, I have confirmed this works.

@steplong

Hi, @danikhan632 I'm trying this PR out and I'm seeing a correctness issue where the result of the matmul doesn't match the expected output. I'm looking into it right now, but I was wondering if you were able to run the compiled test_matmul with QEMU.

Yeah, I've been able to get it to work with some caveats: only on Ubuntu 22.04, not 20.04; QEMU must be built from source. There are some more changes I need to push, but yes, I have confirmed this works.

Could you share those changes? Right now, I have modified the generated launcher.cpp to build as an executable, linked it with the generated kernel.o, passed two tensors to matmul_kernel, and compared the output to an expected result.

@danikhan632
Contributor Author

Hi, @danikhan632 I'm trying this PR out and I'm seeing a correctness issue where the result of the matmul doesn't match the expected output. I'm looking into it right now, but I was wondering if you were able to run the compiled test_matmul with QEMU.

Yeah, I've been able to get it to work with some caveats: only on Ubuntu 22.04, not 20.04; QEMU must be built from source. There are some more changes I need to push, but yes, I have confirmed this works.

Could you share those changes? Right now, I have modified the generated launcher.cpp to build as an executable, linked it with the generated kernel.o, passed two tensors to matmul_kernel, and compared the output to an expected result.

Sure, give me a sec to push those changes. By the way, I wanted to know what your output looks like: is it off by a small amount or by a lot?

@steplong

Sure, give me a sec to push those changes. By the way, I wanted to know what your output looks like: is it off by a small amount or by a lot?

Hmm, a lot of the elements that differ from the expected output are 0, but some aren't.

@danikhan632
Contributor Author

Sure, give me a sec to push those changes. By the way, I wanted to know what your output looks like: is it off by a small amount or by a lot?

Hmm, a lot of the elements that differ from the expected output are 0, but some aren't.

I don't think the widening works as of now, so fp16 -> fp32 doesn't work. Change the accumulator to fp16.

@steplong

steplong commented Jul 1, 2024

Sure, give me a sec to push those changes. By the way, I wanted to know what your output looks like: is it off by a small amount or by a lot?

Hmm, a lot of the elements that differ from the expected output are 0, but some aren't.

I don't think the widening works as of now, so fp16 -> fp32 doesn't work. Change the accumulator to fp16.

diff --git a/backend/compiler.py b/backend/compiler.py
index a0965e9..0535430 100644
--- a/backend/compiler.py
+++ b/backend/compiler.py
@@ -184,8 +184,8 @@ def _llir_to_bin(llir: str, metadata):
     assert len(matches) == 1
     metadata["name"] = matches[0]
     with tempfile.TemporaryDirectory() as tmpdir:
-        src_path = os.path.join(tmpdir, "kernel.ll")
-        dst_path = os.path.join(tmpdir, "kernel.o")
+        src_path = os.path.join(os.getcwd(), "kernel.ll")
+        dst_path = os.path.join(os.getcwd(), "kernel.o")
         Path(src_path).write_text(llir)
         llc_path = _get_llvm_bin_path("llc")
         if  _get_triton_SME_path() == "":
diff --git a/python/examples/test_matmul.py b/python/examples/test_matmul.py
index caa072c..7ef952d 100644
--- a/python/examples/test_matmul.py
+++ b/python/examples/test_matmul.py
@@ -85,14 +85,14 @@ def matmul_kernel(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
-        accumulator += tl.dot(a, b)
+        accumulator += tl.dot(a, b).to(tl.float16)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -100,7 +100,7 @@ def matmul_kernel(
     # while the accumulator is still in FP32!
     if ACTIVATION == "leaky_relu":
         accumulator = leaky_relu(accumulator)
-    c = accumulator.to(tl.float32)
+    c = accumulator.to(tl.float16)

This is the change I'm trying and I'm not seeing any changes in the output.

@danikhan632
Contributor Author

Sure, give me a sec to push those changes. By the way, I wanted to know what your output looks like: is it off by a small amount or by a lot?

Hmm, a lot of the elements that differ from the expected output are 0, but some aren't.

I don't think the widening works as of now, so fp16 -> fp32 doesn't work. Change the accumulator to fp16.


diff --git a/backend/compiler.py b/backend/compiler.py
index a0965e9..0535430 100644
--- a/backend/compiler.py
+++ b/backend/compiler.py
@@ -184,8 +184,8 @@ def _llir_to_bin(llir: str, metadata):
     assert len(matches) == 1
     metadata["name"] = matches[0]
     with tempfile.TemporaryDirectory() as tmpdir:
-        src_path = os.path.join(tmpdir, "kernel.ll")
-        dst_path = os.path.join(tmpdir, "kernel.o")
+        src_path = os.path.join(os.getcwd(), "kernel.ll")
+        dst_path = os.path.join(os.getcwd(), "kernel.o")
         Path(src_path).write_text(llir)
         llc_path = _get_llvm_bin_path("llc")
         if  _get_triton_SME_path() == "":
diff --git a/python/examples/test_matmul.py b/python/examples/test_matmul.py
index caa072c..7ef952d 100644
--- a/python/examples/test_matmul.py
+++ b/python/examples/test_matmul.py
@@ -85,14 +85,14 @@ def matmul_kernel(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
-        accumulator += tl.dot(a, b)
+        accumulator += tl.dot(a, b).to(tl.float16)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -100,7 +100,7 @@ def matmul_kernel(
     # while the accumulator is still in FP32!
     if ACTIVATION == "leaky_relu":
         accumulator = leaky_relu(accumulator)
-    c = accumulator.to(tl.float32)
+    c = accumulator.to(tl.float16)

This is the change I'm trying and I'm not seeing any changes in the output.

OK, let me try to fix that. I broke my env, so it's taking me longer than it should.

@danikhan632
Contributor Author

Sure, give me a sec to push those changes. By the way, I wanted to know what your output looks like: is it off by a small amount or by a lot?

Hmm, a lot of the elements that differ from the expected output are 0, but some aren't.

I don't think the widening works as of now, so fp16 -> fp32 doesn't work. Change the accumulator to fp16.

diff --git a/backend/compiler.py b/backend/compiler.py
index a0965e9..0535430 100644
--- a/backend/compiler.py
+++ b/backend/compiler.py
@@ -184,8 +184,8 @@ def _llir_to_bin(llir: str, metadata):
     assert len(matches) == 1
     metadata["name"] = matches[0]
     with tempfile.TemporaryDirectory() as tmpdir:
-        src_path = os.path.join(tmpdir, "kernel.ll")
-        dst_path = os.path.join(tmpdir, "kernel.o")
+        src_path = os.path.join(os.getcwd(), "kernel.ll")
+        dst_path = os.path.join(os.getcwd(), "kernel.o")
         Path(src_path).write_text(llir)
         llc_path = _get_llvm_bin_path("llc")
         if  _get_triton_SME_path() == "":
diff --git a/python/examples/test_matmul.py b/python/examples/test_matmul.py
index caa072c..7ef952d 100644
--- a/python/examples/test_matmul.py
+++ b/python/examples/test_matmul.py
@@ -85,14 +85,14 @@ def matmul_kernel(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
-        accumulator += tl.dot(a, b)
+        accumulator += tl.dot(a, b).to(tl.float16)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -100,7 +100,7 @@ def matmul_kernel(
     # while the accumulator is still in FP32!
     if ACTIVATION == "leaky_relu":
         accumulator = leaky_relu(accumulator)
-    c = accumulator.to(tl.float32)
+    c = accumulator.to(tl.float16)

This is the change I'm trying and I'm not seeing any changes in the output.

I've had issues recreating this behavior; could you reach out to [email protected] with more details?
