[DLTI] Enable obtaining tile size from module-associated attr
By way of demonstration, this change only enables obtaining the tile size in the
-tile-consumer-and-fuse-producers pass.
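The hint travels on the module as a DLTI target system spec. A minimal sketch of the expected IR, mirroring the test cases added below ("CPU" and "tile_size" are the identifiers the pass queries):

    module attributes { dlti.target_system_spec = #dlti.target_system_spec<
        "CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>> } {
      // linalg ops tiled by the pass inside this module pick up tile size 32.
    }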
rolfmorel committed Aug 8, 2024
1 parent ce33de5 commit c730cce
Showing 2 changed files with 81 additions and 1 deletion.
27 changes: 26 additions & 1 deletion lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
@@ -449,7 +449,32 @@ static FailureOr<scf::SCFTileAndFuseResult> fuseWithEltwise(
 // Trivial tile selection: return a default of 32 if the dimension is
 // statically known, divides evenly by the tile, and yields enough iterations.
 static int64_t getTileForDim(linalg::LinalgOp linalgOp, unsigned dim) {
-  const int64_t tile = 32;
+  int64_t tile = 32;
+
+  // Check if a tile size hint is associated with the IR via DLTI.
+  auto deriveFromDLTI = [&](ModuleOp moduleOp) {
+    if (!moduleOp)
+      return;
+    TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec();
+    if (!sysSpec)
+      return;
+    auto deviceId = StringAttr::get(linalgOp->getContext(), "CPU");
+    auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId);
+    if (!deviceSpec)
+      return;
+    auto tileSizeId = StringAttr::get(linalgOp->getContext(), "tile_size");
+    DataLayoutEntryInterface entry =
+        (*deviceSpec).getSpecForIdentifier(tileSizeId);
+    if (!entry)
+      return;
+    Attribute value = entry.getValue();
+    if (auto intAttr = llvm::dyn_cast<IntegerAttr>(value))
+      tile = intAttr.getInt();
+    // TODO: might want to print a warning if tile_size exists as a key but
+    // the associated attribute has an unexpected type.
+  };
+  deriveFromDLTI(linalgOp->getParentOfType<mlir::ModuleOp>());
+
   SmallVector<int64_t, 4> loopsRange = linalgOp.getStaticLoopRanges();
   if (loopsRange[dim] == ShapedType::kDynamic)
     return tile;
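Note the TODO above: if the "tile_size" key is present but its value is not an IntegerAttr, the dyn_cast fails and the lambda silently leaves the default of 32 in place. A hypothetical spec that would hit this path (a sketch, not part of this commit):

    module attributes { dlti.target_system_spec = #dlti.target_system_spec<
        "CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", "large">>> } {
      // The string value "large" is not an IntegerAttr, so `tile` stays 32.
    }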
55 changes: 55 additions & 0 deletions test/Passes/tile-and-fuse-default.mlir
@@ -729,3 +729,58 @@ func.func @contraction(%arg0: tensor<16x1xf32>, %arg1: tensor<1x32xf32>) -> tens
 
 // CHECK-LABEL: contraction
 // CHECK-NOT: scf.for
+
+// -----
+
+// CHECK-LABEL: dlti_tile_size_32
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>> } {
+  func.func @dlti_tile_size_32(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>)
+      -> tensor<2048x2048xf32> {
+    // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+    // CHECK: scf.for %{{.+}} step %[[C32]]
+    // CHECK-NEXT: scf.for %{{.+}} step %[[C32]]
+    // CHECK: %{{.+}} = linalg.matmul ins(%{{.+}}, %{{.+}} : tensor<32x2048xf32>, tensor<2048x32xf32>)
+    // CHECK-SAME: outs(%{{.+}} : tensor<32x32xf32>)
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x2048xf32>, tensor<2048x2048xf32>)
+                       outs(%arg2 : tensor<2048x2048xf32>)
+        -> tensor<2048x2048xf32>
+    return %0 : tensor<2048x2048xf32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: dlti_tile_size_64
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 64 : i32>>> } {
+  func.func @dlti_tile_size_64(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>)
+      -> tensor<2048x2048xf32> {
+    // CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
+    // CHECK: scf.for %{{.+}} step %[[C64]]
+    // CHECK-NEXT: scf.for %{{.+}} step %[[C64]]
+    // CHECK: %{{.+}} = linalg.matmul ins(%{{.+}}, %{{.+}} : tensor<64x2048xf32>, tensor<2048x64xf32>)
+    // CHECK-SAME: outs(%{{.+}} : tensor<64x64xf32>)
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x2048xf32>, tensor<2048x2048xf32>)
+                       outs(%arg2 : tensor<2048x2048xf32>)
+        -> tensor<2048x2048xf32>
+    return %0 : tensor<2048x2048xf32>
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL: dlti_tile_size_16
+module attributes { dlti.target_system_spec = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 16 : i32>>> } {
+  func.func @dlti_tile_size_16(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>)
+      -> tensor<2048x2048xf32> {
+    // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+    // CHECK: scf.for %{{.+}} step %[[C16]]
+    // CHECK-NEXT: scf.for %{{.+}} step %[[C16]]
+    // CHECK: %{{.+}} = linalg.matmul ins(%{{.+}}, %{{.+}} : tensor<16x2048xf32>, tensor<2048x16xf32>)
+    // CHECK-SAME: outs(%{{.+}} : tensor<16x16xf32>)
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x2048xf32>, tensor<2048x2048xf32>)
+                       outs(%arg2 : tensor<2048x2048xf32>)
+        -> tensor<2048x2048xf32>
+    return %0 : tensor<2048x2048xf32>
+  }
+}
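
For completeness, a module carrying no dlti.target_system_spec attribute takes the early-return path in deriveFromDLTI and keeps the hard-coded default, so loops get step 32. A sketch of such a case (hypothetical, not part of this commit):

    module {
      // No DLTI spec: getTileForDim falls back to the default tile size of 32.
      func.func @no_dlti_default(%arg0: tensor<2048x2048xf32>, %arg1: tensor<2048x2048xf32>, %arg2: tensor<2048x2048xf32>) -> tensor<2048x2048xf32> {
        %0 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x2048xf32>, tensor<2048x2048xf32>)
                           outs(%arg2 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>
        return %0 : tensor<2048x2048xf32>
      }
    }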
