diff --git a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
index b7804bfee..bc46c323d 100644
--- a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
+++ b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
@@ -449,7 +449,32 @@ static FailureOr fuseWithEltwise(
 // Trivial tile selection. If the dimension is statically known, it perfectly
 // divides the tile, and we have enough iterations return a default of 32.
 static int64_t getTileForDim(linalg::LinalgOp linalgOp, unsigned dim) {
-  const int64_t tile = 32;
+  int64_t tile = 32;
+
+  // Check if a tile size hint is associated with the IR via DLTI.
+  auto deriveFromDLTI = [&](ModuleOp moduleOp) {
+    if (!moduleOp)
+      return;
+    TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec();
+    if (!sysSpec)
+      return;
+    auto deviceId = StringAttr::get(linalgOp->getContext(), "CPU");
+    auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId);
+    if (!deviceSpec)
+      return;
+    auto tileSizeId = StringAttr::get(linalgOp->getContext(), "tile_size");
+    DataLayoutEntryInterface entry =
+        (*deviceSpec).getSpecForIdentifier(tileSizeId);
+    if (!entry)
+      return;
+    Attribute value = entry.getValue();
+    if (auto intAttr = llvm::dyn_cast<IntegerAttr>(value))
+      tile = intAttr.getInt();
+    // TODO: might want to print a warning if tile_size exists as a key but the
+    // associated attribute has an unexpected type.
+  };
+  deriveFromDLTI(linalgOp->getParentOfType<ModuleOp>());
+
   SmallVector<int64_t> loopsRange = linalgOp.getStaticLoopRanges();
   if (loopsRange[dim] == ShapedType::kDynamic)
     return tile;
diff --git a/test/Passes/tile-and-fuse-default.mlir b/test/Passes/tile-and-fuse-default.mlir
index 3a1bf133c..d1e5f1079 100644
--- a/test/Passes/tile-and-fuse-default.mlir
+++ b/test/Passes/tile-and-fuse-default.mlir
@@ -729,3 +729,58 @@ func.func @contraction(%arg0: tensor<16x1xf32>, %arg1: tensor<1x32xf32>) -> tens
 
 // CHECK-LABEL: contraction
 // CHECK-NOT: scf.for
+
+// -----
+
+// CHECK-LABEL: dlti_tile_size_32
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>> } {
+  func.func @dlti_tile_size_32(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>)
+      -> tensor<2048x2048xf32> {
+    // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+    // CHECK: scf.for %{{.+}} step %[[C32]]
+    // CHECK-NEXT: scf.for %{{.+}} step %[[C32]]
+    // CHECK: %{{.+}} = linalg.matmul ins(%{{.+}}, %{{.+}} : tensor<32x2048xf32>, tensor<2048x32xf32>)
+    // CHECK-SAME: outs(%{{.+}} : tensor<32x32xf32>)
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<2048x2048xf32>, tensor<2048x2048xf32>)
+                       outs(%arg2: tensor<2048x2048xf32>)
+      -> tensor<2048x2048xf32>
+    return %0 : tensor<2048x2048xf32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: dlti_tile_size_64
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 64 : i32>>> } {
+  func.func @dlti_tile_size_64(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>)
+      -> tensor<2048x2048xf32> {
+    // CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
+    // CHECK: scf.for %{{.+}} step %[[C64]]
+    // CHECK-NEXT: scf.for %{{.+}} step %[[C64]]
+    // CHECK: %{{.+}} = linalg.matmul ins(%{{.+}}, %{{.+}} : tensor<64x2048xf32>, tensor<2048x64xf32>)
+    // CHECK-SAME: outs(%{{.+}} : tensor<64x64xf32>)
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<2048x2048xf32>, tensor<2048x2048xf32>)
+                       outs(%arg2: tensor<2048x2048xf32>)
+      -> tensor<2048x2048xf32>
+    return %0 : tensor<2048x2048xf32>
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL: dlti_tile_size_16
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 16 : i32>>> } {
+  func.func @dlti_tile_size_16(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>)
+      -> tensor<2048x2048xf32> {
+    // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+    // CHECK: scf.for %{{.+}} step %[[C16]]
+    // CHECK-NEXT: scf.for %{{.+}} step %[[C16]]
+    // CHECK: %{{.+}} = linalg.matmul ins(%{{.+}}, %{{.+}} : tensor<16x2048xf32>, tensor<2048x16xf32>)
+    // CHECK-SAME: outs(%{{.+}} : tensor<16x16xf32>)
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<2048x2048xf32>, tensor<2048x2048xf32>)
+                       outs(%arg2: tensor<2048x2048xf32>)
+      -> tensor<2048x2048xf32>
+    return %0 : tensor<2048x2048xf32>
+  }
+}
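
Note on usage: getTileForDim only picks up the hint when it is attached to the enclosing module, under the "CPU" device ID with an integer "tile_size" entry; otherwise the default tile of 32 is kept. A minimal sketch of such a module, modeled on the new test cases above (the function name and shapes here are illustrative only):

module attributes {
  dlti.target_system_spec = #dlti.target_system_spec<
    "CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 64 : i32>>>
} {
  func.func @hinted_matmul(%a: tensor<2048x2048xf32>, %b: tensor<2048x2048xf32>,
                           %c: tensor<2048x2048xf32>) -> tensor<2048x2048xf32> {
    // With the hint above, tiling uses 64 instead of the default 32,
    // subject to the remaining static-shape checks in getTileForDim.
    %0 = linalg.matmul ins(%a, %b: tensor<2048x2048xf32>, tensor<2048x2048xf32>)
                       outs(%c: tensor<2048x2048xf32>) -> tensor<2048x2048xf32>
    return %0 : tensor<2048x2048xf32>
  }
}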