Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoBump] Merge with f6721e59 (Oct 08) (73) #443

Merged
merged 4 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
Location loc, SmallVector<int64_t> dimensions,
Value input, Value &result);

// Flips an input tensor based on the values of axis list.
Value flipTensor(PatternRewriter &rewriter, Location loc, Value input,
SmallVector<int64_t> axis);

} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
49 changes: 40 additions & 9 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
int64_t &dim,
SmallVector<Value> &resultShape,
SmallVector<Value> &offsets,
SmallVector<Value> &strides) {
Expand All @@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant");

Expand Down Expand Up @@ -1857,14 +1857,46 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));

SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
SmallVector<Value> resultShape, offsets, strides;
int64_t dim;
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
AtenSliceTensorOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure();
}

// If stride is negative, then flip the input tensor corresponding to that
// dim, update the stride for flipped tensor by multiplying it by -1, and
// update the offset as follows:
// flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
//
// For example:
// Input = [0, 1, 2, 3, 4, 5]
// stride = [-2], result_shape = [2], offset = [3]
// Result = [3, 1]
// After flipping:
// Input = [5, 4, 3, 2, 1, 0]
// stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
// Result = [3, 1]

Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
SmallVector<int64_t>{dim});
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value isNegativeStride = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, strides[dim], zero);
strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
Value resShapeMulStride =
rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
Value flippedOffset =
rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
offsets[dim] = rewriter.create<arith::SelectOp>(
loc, isNegativeStride, flippedOffset, offsets[dim]);

input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
flippedInput, input);

SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
auto sliceType = RankedTensorType::get(
dynShape, resultType.getElementType(), resultType.getEncoding());
Expand Down Expand Up @@ -2095,12 +2127,11 @@ class ConvertAtenSliceScatterOp
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));

SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
SmallVector<Value> resultShape, offsets, strides;
int64_t dim;
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
AtenSliceScatterOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure();
}

Expand Down
39 changes: 1 addition & 38 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,9 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
ConversionPatternRewriter &rewriter) const override {

Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Value self = adaptor.getSelf();
auto selfRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
Type elementType =
cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
Value c1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

SmallVector<int64_t> axis;
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
Expand All @@ -242,40 +237,8 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
}
}

// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);

Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);

SmallVector<utils::IteratorType> iteratorTypes(
selfRank, utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, context));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, self.getType(), self, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(
loc, dims[flipDim], indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, self, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);

Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);

return success();
}
};
Expand Down
41 changes: 41 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
.getResult(0);
return success();
}

// Flips an input tensor based on the values of axis list.
Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc,
Value input, SmallVector<int64_t> axis) {
Value c1 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
auto selfRank = cast<RankedTensorType>(input.getType()).getRank();

// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);

Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, input), elementType);

SmallVector<utils::IteratorType> iteratorTypes(selfRank,
utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext()));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, input.getType(), input, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(loc, dims[flipDim],
indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, input, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);
return flipped;
}
Loading