Fix build failure in issues #178 (#187)
Fixes the following compilation errors in
include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:


/home/runner/work/triton-shared/triton-shared/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:839:68:
error: no member named 'getFile' in 'mlir::triton::AssertOp'
llvm::formatv("{0}.py:{1}: {2} Assertion `{3}` failed", op.getFile(),
                                                                  ~~ ^

/home/runner/work/triton-shared/triton-shared/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:840:26:
error: no member named 'getLine' in 'mlir::triton::AssertOp'
                        op.getLine(), op.getFunc(), op.getMessage());
                        ~~ ^

/home/runner/work/triton-shared/triton-shared/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:840:40:
error: no member named 'getFunc' in 'mlir::triton::AssertOp'
                        op.getLine(), op.getFunc(), op.getMessage());
                                      ~~ ^
  3 errors generated.

This fix builds with triton @ab07e5472bcb414a0c8dd7ecab80f84370c4894e,
and llvm @cfd3289a1f9a87e220737a634904a886a82d424a.
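
For reference, the fix drops the removed accessors and formats only the assertion
text. A minimal sketch of the updated call in AssertConverter, mirroring the
ConversionPatterns.hpp hunk below (here `op` is the matched triton::AssertOp and
`condVal` is the lowered condition value):

    // tt.assert no longer carries file/line/function info, so the cf.assert
    // message is built from the assertion message alone.
    auto assertMessage =
        llvm::formatv("Assertion `{0}` failed", op.getMessage());
    rewriter.create<mlir::cf::AssertOp>(op.getLoc(), condVal,
                                        assertMessage.str());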

---------

Co-authored-by: Zhaoshi Zheng <[email protected]>
zhaoshiz and Zhaoshi Zheng authored Jan 27, 2025
1 parent 89286b4 commit 560c064
Showing 25 changed files with 104 additions and 49 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -5,6 +5,9 @@ set(TRITON_SHARED_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files
include_directories(${Python3_INCLUDE_DIR})
include_directories(${pybind11_INCLUDE_DIR})

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
4 changes: 4 additions & 0 deletions backend/compiler.py
@@ -137,8 +137,12 @@ class CPUOptions:
extern_libs = None
cluster_dims: tuple = (1, 1, 1)
shared: bool = False
# Disable FP8 here since this is a sample CPU backend.
# Target-specific backends can enable it with their supported types.
supported_fp8_dtypes: Tuple[str] = ()
allow_fp8e4nv: bool = False
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
sanitize_overflow: bool = True

def __post_init__(self):
pass
17 changes: 15 additions & 2 deletions backend/driver.py
@@ -16,6 +16,8 @@
def _ty_to_cpp(ty):
if ty[0] == '*':
return "void*"
if ty == "constexpr":
return "PyObject*"
return {
"i1": "int32_t",
"i8": "int8_t",
@@ -37,11 +39,14 @@ def _ty_to_cpp(ty):
def _extracted_type(ty):
if ty[0] == '*':
return "PyObject*"
if ty == "constexpr":
return "PyObject*"
return _ty_to_cpp(ty)

def _format_of(ty):
return {
"PyObject*": "O",
"constexpr": "O",
"float": "f",
"double": "d",
"long": "l",
@@ -61,10 +66,10 @@ def _generate_launcher(constants, signature, kernel_name):
format = "iiiOOOO" + args_format
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''

kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants)
kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if ty != "constexpr")
kernel_arg_decls += ', ' if kernel_arg_decls else ''

kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if i not in constants)
kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if ty != "constexpr")
kernel_parameters += ', ' if kernel_parameters else ''

return f"""
@@ -347,6 +352,10 @@ def __init__(self):
def is_active():
return False

def get_benchmarker(self):
from triton.testing import do_bench
return do_bench

def get_device_capability(self):
return ("cpu", 0)

@@ -365,5 +374,9 @@ def set_current_device(self, device):
def get_current_target(self):
return GPUTarget("cpu", 0, 0)

def get_active_torch_device(self):
import torch
return torch.device("cpu")

def assemble_tensormap_to_arg(self, tensormaps_info, args):
return args
include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
@@ -837,8 +837,7 @@ struct AssertConverter : public OpConversionPattern<triton::AssertOp> {
}

auto assertMessage =
llvm::formatv("{0}.py:{1}: {2} Assertion `{3}` failed", op.getFile(),
op.getLine(), op.getFunc(), op.getMessage());
llvm::formatv("Assertion `{0}` failed", op.getMessage());
rewriter.create<mlir::cf::AssertOp>(op.getLoc(), condVal,
assertMessage.str());

11 changes: 6 additions & 5 deletions lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
@@ -177,8 +177,9 @@ struct MakeTensorPtrConverter
/* result shape */
SmallVector<int64_t>{

// Row stays the same
resultShape[0],
// Row stays the same, but MLIR no longer allows a static size
// here, so use dynamic instead.
ShapedType::kDynamic,

// Column is dynamic, in most cases, this
// should be the same as the original column.
@@ -286,9 +287,9 @@ struct MakeTensorPtrConverter
// around.
ShapedType::kDynamic,

// Col stays the same.
resultShape[1],
});
// Col stays the same (resultShape[1]), but MLIR no longer allows
// a static size here, so use dynamic instead.
ShapedType::kDynamic});

Value rowSize = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(op.getSizes()[0]));
8 changes: 4 additions & 4 deletions lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
@@ -84,7 +84,7 @@ class LoopTypeConverter : public TypeConverter {
// reinterpret_cast.
addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
auto reinterpretCast =
inputs[0].getDefiningOp<memref::ReinterpretCastOp>();
if (!reinterpretCast) {
@@ -99,14 +99,14 @@

addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

addArgumentMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
@@ -123,7 +123,7 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
addTargetMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
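
Note: the materialization hooks touched in this commit (here and in the other
converters below) change their callbacks to return a plain Value instead of
std::optional<Value>, matching the current MLIR TypeConverter API. A minimal
sketch of the updated shape, assuming a TypeConverter named `converter` (the
cast-building body is unchanged from the code above):

    // Materialization callbacks now return Value directly; returning a null
    // Value indicates the callback could not produce a materialization.
    converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
                                          ValueRange inputs,
                                          Location loc) -> Value {
      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
          .getResult(0);
    });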
2 changes: 1 addition & 1 deletion lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp
@@ -61,7 +61,7 @@ class TritonFunctionSignatureConverter : public TypeConverter {

auto createUnrealizedCast = [&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};
14 changes: 7 additions & 7 deletions lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp
@@ -74,7 +74,7 @@ class TritonToStructuredPass
RewritePatternSet patterns(&getContext());

auto context = &getContext();
OneToNTypeConverter converter;
TypeConverter converter;
converter.addConversion([](Type type) { return type; });

// We are doing a 1->1 type conversion here, where a triton pointer type
@@ -145,10 +145,10 @@ class TritonToStructuredPass
// Compute the target materialization, given a value with the pointer type,
// convert that value to a tuple type.
converter.addTargetMaterialization(
[](OpBuilder &builder, TypeRange resultTypes, Value input,
Location loc) -> std::optional<SmallVector<Value>> {
[](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc) -> SmallVector<Value> {
return builder
.create<UnrealizedConversionCastOp>(loc, resultTypes, input)
.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs.front())
->getResults();
});

@@ -172,7 +172,7 @@ class TritonToStructuredPass
auto moduleOp = getOperation();

auto context = &getContext();
OneToNTypeConverter converter;
TypeConverter converter;
converter.addConversion([](Type type) { return type; });

// We are doing a 1->N type conversion here, where a pointer tuple type
@@ -208,10 +208,10 @@ class TritonToStructuredPass
// At the end of pointer analysis, we will use the PtrState to create the
// correct offsets, strides, and remove these ops.
converter.addTargetMaterialization([](OpBuilder &builder,
TypeRange resultTypes, Value input,
TypeRange resultTypes, ValueRange inputs,
Location loc) {
auto placeholder = builder.create<tts::GetStructuredStateOp>(
loc, input.getDefiningOp()->getOperand(0));
loc, inputs.front().getDefiningOp()->getOperand(0));
assert(llvm::equal(placeholder.getResultTypes(), resultTypes));
return placeholder.getResults();
});
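
Note: TritonToStructuredPass switches from OneToNTypeConverter to a plain
TypeConverter, reflecting that one-to-N conversions are now handled by the
regular dialect conversion driver. Its target materializations therefore take a
ValueRange instead of a single Value and return the results as a
SmallVector<Value>. A minimal sketch of the updated registration, matching the
hunk above:

    // 1->N materialization: cast the incoming value to the N result types via
    // an unrealized_conversion_cast and hand back all of its results.
    converter.addTargetMaterialization(
        [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
           Location loc) -> SmallVector<Value> {
          return builder
              .create<UnrealizedConversionCastOp>(loc, resultTypes,
                                                  inputs.front())
              ->getResults();
        });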
@@ -53,7 +53,7 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
addTargetMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
24 changes: 22 additions & 2 deletions python/examples/conftest.py
@@ -71,12 +71,32 @@ def device(request):
"test_if",
"test_if_call",
"test_convert2d",
"test_convertmma2mma",
"test_dot_max_num_imprecise_acc",
"test_propagate_nan",
"test_clamp_symmetric",
"test_temp_var_in_loop",
"test_math_extern"
"test_math_extern",
# attribute 'launch_cooperative_grid' not supported
"test_load_scope_sem_coop_grid_cta_one",
# fp8 support on CPUs is unclear
"test_scaled_dot",
# triton-shared-opt failures:
# PtrAnalysis: encountered addptr operand produced by an unsupported operation
"test_chained_reductions",
# failed to legalize unresolved materialization
"test_constexpr_if_return",
"test_unroll_attr",
# Dialect `ub' not found for custom op 'ub.poison'
"test_poison_return",
# tt.gather not supported yet
"test_gather",
"test_gather_warp_shuffle",
# device 'cpu' does not have 'index'
"test_zero_strided_tensors",
# hard-coded with 'ttg' attributes
"test_convert_mma2mma",
"test_local_load_store",
"test_local_load_store_mma"
}

# probably different version of MLIR on the nightly build machine is complaining
8 changes: 6 additions & 2 deletions python/examples/test_reduce.py
@@ -44,8 +44,12 @@ def test(device):
# TODO: need to check some conditions otherwise the code below does not make any difference for the test
src = triton.compiler.ASTSource(
fn=reduce_kernel_2d,
signature="*fp32,*fp32,i32,i32",
constants={"BLOCK_SIZE": 32}
signature={"x_ptr": "*fp32",
"output_ptr": "*fp32",
"stride": "i32",
"n_elements": "i32",
"BLOCK_SIZE": "constexpr"},
constexprs={"BLOCK_SIZE": 32}
)
ret = triton.compile(
src,
@@ -13,9 +13,9 @@ module {
%8 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
%9 = tt.addptr %8, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
%10 = tt.load %9 : tensor<32x!tt.ptr<f32>>
%11 = tt.reshape %10 {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>
%11 = tt.reshape %10 allow_reorder : tensor<32xf32> -> tensor<1x32xf32>
%12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32>
%13 = tt.reshape %12 {allow_reorder = false} : tensor<64x32xf32> -> tensor<2048xf32>
%13 = tt.reshape %12 allow_reorder : tensor<64x32xf32> -> tensor<2048xf32>
%14 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<2048x!tt.ptr<f32>>
%15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr<f32>>, tensor<2048xi32>
tt.store %15, %13 : tensor<2048x!tt.ptr<f32>>
8 changes: 8 additions & 0 deletions test/Conversion/StructuredToMemref/get_num_programs.mlir
@@ -1,3 +1,11 @@
// XFAIL: *
// Note: PtrAnalysis pass can create a tts.makeptr for below pattern:
// %3 = arith.constant 0 : index
// %6 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>>
// %7 = tt.addptr %6, %3 : tensor<1x!tt.ptr<i32>>, tensor<1xi32>
// But not if creating constant 0 and add it to a pointer is optimized away:
// %6 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>>
// A patch that rewrites tt.splat in such case will be sent separately.
// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s

module {
4 changes: 2 additions & 2 deletions test/Conversion/StructuredToMemref/triton_assert.mlir
@@ -3,7 +3,7 @@ tt.func public @assert_lol(%arg0: i32) {
%c0_i32 = arith.constant 0 : i32
%0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
%1 = tt.splat %0 : i1 -> tensor<1xi1>
tt.assert %1, "lol", "", "", 0 : tensor<1xi1>
tt.assert %1, "lol" : tensor<1xi1>
tt.return
}

@@ -12,6 +12,6 @@ tt.func public @assert_lol(%arg0: i32) {
// CHECK-SAME: ([[PARAM_0_:%.+]]: i32, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK: [[VAR_0_:%.+]] = arith.cmpi sgt, [[PARAM_0_]], [[CST_0_]] : i32
// CHECK: cf.assert [[VAR_0_]], ".py:0: Assertion `lol` failed"
// CHECK: cf.assert [[VAR_0_]], "Assertion `lol` failed"
// CHECK: return
// CHECK: }
@@ -90,17 +90,17 @@ module {
// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_14_]], [[CST_4_]] : index
// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_16_]], [[VAR_5_]] : index
// CHECK: [[VAR_18_:%.+]] = arith.subi [[VAR_17_]], [[VAR_14_]] : index
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_13_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_18_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_13_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_18_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_19_:%.+]] = arith.subi [[CST_4_]], [[VAR_18_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_19_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_19_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32>
// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>)
// CHECK: [[VAR_20_:%.+]] = arith.minsi [[VAR_18_]], [[CST_4_]] : index
// CHECK-DAG: [[VAR_21_:%.+]] = arith.subi [[CST_4_]], [[VAR_20_]] : index
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_21_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_21_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>>
// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_20_]]{{.}} [2, [[VAR_21_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>>
8 changes: 4 additions & 4 deletions test/Conversion/StructuredToMemref/wraparound_stacked.mlir
@@ -83,17 +83,17 @@ module {
// CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_3_]], [[VAR_12_]] : index
// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_13_]], [[VAR_11_]] : index
// CHECK: [[VAR_15_:%.+]] = arith.divsi [[VAR_14_]], [[VAR_1_]] : index
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_15_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x4xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_15_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_16_:%.+]] = arith.subi [[CST_4_]], [[VAR_15_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x4xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32>
// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>)
// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_15_]], [[CST_4_]] : index
// CHECK-DAG: [[VAR_18_:%.+]] = arith.subi [[CST_4_]], [[VAR_17_]] : index
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1]>>
// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_17_]], 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1], offset: ?>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<?x3xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[4, 1]>>