Implement aten.reflection_pad2d lowering to linalg
frederik-h authored and stellaraccident committed Jan 11, 2024
1 parent aee1fca commit 0860c41
Showing 8 changed files with 553 additions and 0 deletions.
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7893,6 +7893,30 @@ def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
}];
}

def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$padding
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenReflectionPad2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenReflectionPad2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
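
// A sketch of how this op typically appears in the torch dialect, using the
// default assembly format generated above; the SSA names and the
// !torch.vtensor shapes are illustrative only (they match the 1x3x3 input
// padded by [1,2,1,2] used as the running example in the lowering):
//
//   %padded = torch.aten.reflection_pad2d %t, %pad
//       : !torch.vtensor<[1,3,3],f32>, !torch.list<int> -> !torch.vtensor<[1,6,6],f32>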

def Torch_AtenPadOp : Torch_Op<"aten.pad", [
AllowsTypeRefinement,
HasValueSemantics,
290 changes: 290 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -244,6 +244,294 @@ class ConvertAtenReflectionPad1dOp
};
}

namespace {

// Lower the aten.reflection_pad2d operator into a sequence of
// tensor.extract_slice, linalg.generic, and tensor.insert_slice
// operations.

// To understand the lowering, consider this PyTorch example:
//
// >>> t = torch.tensor([[[1.0,2,3],[4,5,6], [7,8,9]]])
// >>> t
// tensor([[[1., 2., 3.],
// [4., 5., 6.],
// [7., 8., 9.]]])
// >>> torch.ops.aten.reflection_pad2d(t, [1,2,1,2])
// tensor([[[5., 4., 5., 6., 5., 4.],
// [2., 1., 2., 3., 2., 1.],
// [5., 4., 5., 6., 5., 4.],
// [8., 7., 8., 9., 8., 7.],
// [5., 4., 5., 6., 5., 4.],
// [2., 1., 2., 3., 2., 1.]]])
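//
// Note that the padding list is laid out as [left, right, top, bottom],
// so [1,2,1,2] above means left = 1, right = 2, top = 1, bottom = 2, and
// the padded result has shape (3 + 1 + 2) x (3 + 1 + 2) = 6 x 6.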
//
// The result can be subdivided into "tiles" corresponding to either
// the input tensor (in the center) or slices of the input tensor
// whose width and height is determined by the padding sizes and which
// are reflected through the side of the central input tensor that
// they touch.
// In the example above, the tiles are:
// top left: [[5]]
// top center: [[4,5,6]]
// top right: [[5,4]]
// center left: [[2],[5],[8]]
// center: copy of the input tensor
// center right: [[2,1],[5,4],[8,7]]
// bottom left: [[5],[2]]
// bottom center: [[4,5,6],[1,2,3]]
// bottom right: [[5,4],[2,1]]
//
// The lowering uses a tensor.extract_slice operation to create each tile,
// a linalg.generic for the reflection, and a tensor.insert_slice to
// insert the tile in the resulting tensor.
class ConvertAtenReflectionPad2dOp
: public OpConversionPattern<AtenReflectionPad2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenReflectionPad2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

SmallVector<int64_t> padInts;
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
return rewriter.notifyMatchFailure(
op, "only support constant int pad ranges");

Location loc = op.getLoc();
// Some generic helper functions for creating arithmetic operations.
auto createAdd = [&](Value x, Value y) {
return rewriter.create<arith::AddIOp>(loc, x, y);
};

auto createAdds = [&](std::initializer_list<Value> values) {
assert(values.size() >= 2);
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
createAdd);
};

auto createSub = [&](Value x, Value y) {
return rewriter.create<arith::SubIOp>(loc, x, y);
};

auto createSubs = [&](std::initializer_list<Value> values) {
assert(values.size() >= 2);
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
createSub);
};

// Enums for specifying the coordinates of a tile. An "h" prefix
// is used to stand for "horizontal" and "v" for "vertical"
// throughout.
enum PadHLoc { LEFT = 0, RIGHT = 1, HCENTER = 2 };
enum PadVLoc { TOP = 0, BOTTOM = 1, VCENTER = 2 };

// Helper functions for obtaining information about the operator's
// padding arguments.
auto getHPadArgument = [&](PadHLoc l) {
assert(l < HCENTER);
return padInts[l];
};

auto getVPadArgument = [&](PadVLoc l) {
assert(l < VCENTER);
return padInts[2 + l];
};

auto shouldCreateTile = [&](PadVLoc v, PadHLoc h) {
if (!(h == HCENTER || getHPadArgument(h) > 0))
return false;
if (!(v == VCENTER || getVPadArgument(v) > 0))
return false;

return true;
};
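
// For example, with padding [1, 0, 0, 2] (left = 1, right = 0, top = 0,
// bottom = 2) only the center, center-left, bottom-left, and
// bottom-center tiles are created.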

Value input = adaptor.getSelf();
MLIRContext *context = rewriter.getContext();
auto inputType = llvm::cast<RankedTensorType>(input.getType());
auto outputType = llvm::cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
unsigned numDims = inputType.getRank();

assert(numDims >= 2 && "Not enough input dimensions");

SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
int64_t hDim = numDims - 1;
int64_t vDim = numDims - 2;
Value hDimSize = inputShape[hDim];
Value vDimSize = inputShape[vDim];

assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
"Left padding too large");
assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
"Right padding too large");
assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
"Top padding too large");
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
"Bottom padding too large");

Type indexType = rewriter.getIndexType();
Value zero = getConstant(rewriter, loc, 0, indexType);
Value one = getConstant(rewriter, loc, 1, indexType);

Value tileWidth[3];
tileWidth[HCENTER] = hDimSize;
for (auto h : {LEFT, RIGHT})
tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType);

Value tileHeight[3];
tileHeight[VCENTER] = vDimSize;
for (auto v : {TOP, BOTTOM})
tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType);
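
// For the running example, the tile sizes are tileWidth = {1, 2, 3}
// (LEFT, RIGHT, HCENTER) and tileHeight = {1, 2, 3} (TOP, BOTTOM,
// VCENTER).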

// Helper to reflect/reverse the i-th dimension of an affine map
// without symbols. This only works if it is applied to a tensor
// for which the corresponding dimension has a statically known
// size, which is good enough here since we only use it to reflect
// the padding tiles.
auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
int64_t size) {
AffineExpr d = map.getResult(i);
return map.replace(d, size - d - 1, numDims, 0);
};
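
// For example, with numDims = 3, i = hDim = 2, and size = 2 (the right
// padding of the running example), the identity map
// (d0, d1, d2) -> (d0, d1, d2) becomes (d0, d1, d2) -> (d0, d1, 2 - d2 - 1),
// so position 0 of the tile reads position 1 of the extracted slice and
// vice versa, i.e. the tile is reversed horizontally.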

// Create output shape and tensor
SmallVector<Value> resultShape{inputShape};
resultShape[vDim] =
createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]});
resultShape[hDim] =
createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]});
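
// For the running example, resultShape becomes
// [1, 3 + 1 + 2, 3 + 1 + 2] = [1, 6, 6].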

Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
inputType.getElementType());

// Construction of the tiles

// Example: central left tile
//
// Let m be the width of the left padding, as returned by
// getHPadArgument(LEFT), and let n be the size of the input tensor's
// "vertical" dimension, i.e. vDimSize. Assume that the subtensor of the
// input tensor in the relevant (i.e. last two) dimensions is:
//
// x_1,1 x_1,2 ... x_1,m+1 ...
// x_2,1 x_2,2 ... x_2,m+1 ...
// .
// .
// .
// x_n,1 x_n,2 ... x_n,m+1 ...
//
// The padding tile consists of columns 2, ..., m + 1 of the input in
// reverse column order. The first column gets skipped because it is the
// column through which the reflection happens.
//
// x_1,m+1 x_1,m ... x_1,2
// x_2,m+1 x_2,m ... x_2,2
// .
// .
// .
// x_n,m+1 x_n,m ... x_n,2
//
// The tile will be inserted to the left of the copy of the input tensor
// in the output tensor, i.e. with horizontal offset 0.
// The top padding determines the vertical offset.

// Corner tiles (e.g. (TOP, LEFT)) are reflected through two sides,
// i.e. both their columns and their rows must be reversed.

// Setup information about the tiles

// Compute the offsets for extracting the slice from the
// input. We need to skip the row or column through which
// the tile should be reflected, if any (none for the center tile).
Value extractHOffset[3];
extractHOffset[LEFT] = one;
extractHOffset[HCENTER] = zero;
extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one});

Value extractVOffset[3];
extractVOffset[TOP] = one;
extractVOffset[VCENTER] = zero;
extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one});
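
// E.g. for the running example, extractHOffset[RIGHT] = 3 - 2 - 1 = 0,
// so the right tiles of width 2 are read from input columns 0 and 1,
// skipping the rightmost column, through which the reflection happens.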

// Compute the horizontal and vertical offsets for inserting
// the tiles in the resultTensor.
Value insertHOffset[3];
insertHOffset[LEFT] = zero;
insertHOffset[HCENTER] = tileWidth[LEFT];
insertHOffset[RIGHT] = createAdd(hDimSize, tileWidth[LEFT]);

Value insertVOffset[3];
insertVOffset[TOP] = zero;
insertVOffset[VCENTER] = tileHeight[TOP];
insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]);
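
// E.g. for the running example, insertHOffset[RIGHT] = 3 + 1 = 4 and
// insertVOffset[BOTTOM] = 3 + 1 = 4, so the bottom-right tile ends up in
// rows 4-5 and columns 4-5 of the 6x6 result.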

auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; };
auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; };

SmallVector<utils::IteratorType> iteratorTypes{
numDims, utils::IteratorType::parallel};
auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
SmallVector<Value> allOneStrides(numDims, one);

auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) {
// Create the tile by extracting a slice from the input tensor.
SmallVector<Value> extractShape{inputShape};
extractShape[hDim] = tileWidth[horizontalPos];
extractShape[vDim] = tileHeight[verticalPos];

SmallVector<Value> extractOffsets(numDims, zero);
extractOffsets[hDim] = extractHOffset[horizontalPos];
extractOffsets[vDim] = extractVOffset[verticalPos];

Value tile = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsets, extractShape, allOneStrides);

// Reverse the tile along the horizontal, vertical, or both
// dimensions.
auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
if (shouldHReflect(horizontalPos)) {
inputMap =
reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
}
if (shouldVReflect(verticalPos)) {
inputMap =
reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
}

tile = rewriter
.create<linalg::GenericOp>(
loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
tile, ArrayRef({inputMap, idMap}), iteratorTypes,
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
b.create<linalg::YieldOp>(nestedLoc, args[0]);
})
.getResult(0);

// Insert the tile in the resultTensor.
SmallVector<Value> insertOffsets(numDims, zero);
insertOffsets[hDim] = insertHOffset[horizontalPos];
insertOffsets[vDim] = insertVOffset[verticalPos];

resultTensor = rewriter.create<tensor::InsertSliceOp>(
loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
};

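// Create the center copy and every border tile whose padding is
// non-zero. The center tile (VCENTER, HCENTER) needs no special
// handling: it is extracted with zero offsets, left unreflected, and
// inserted at (top, left), i.e. it is just the copy of the input.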
for (auto v : {TOP, BOTTOM, VCENTER})
for (auto h : {LEFT, RIGHT, HCENTER})
if (shouldCreateTile(v, h))
createTile(v, h);

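// The tensor built above may carry less static shape information than
// the converted result type, so reconcile the two with a cast.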
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, resultTensor);

return success();
}
};
} // namespace

namespace {
class ConvertAtenFlattenUsingIntsOp
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -1552,6 +1840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenReflectionPad1dOp>();
patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
target.addIllegalOp<AtenReflectionPad2dOp>();
patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context);
target.addIllegalOp<AtenFlattenUsingIntsOp>();
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
target.addIllegalOp<AtenViewOp>();