[AutoBump] Merge with fixes of 00efec0b (May 10) (35) #268

Merged · 2 commits merged on Sep 3, 2024
199 changes: 192 additions & 7 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -940,6 +940,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
LogicalResult
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op->getParentOp()->hasAttr("torch.disable_legacy_view"))
return rewriter.notifyMatchFailure(op.getLoc(),
"legacy view lowering diabled");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
@@ -1284,6 +1287,9 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern<AtenViewOp> {
LogicalResult
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op->getParentOp()->hasAttr("torch.disable_legacy_view"))
return rewriter.notifyMatchFailure(op.getLoc(),
"legacy view lowering diabled");
SmallVector<Value> sizes;
if (!getListConstructElements(op.getSize(), sizes))
return op.emitError(
@@ -1319,12 +1325,16 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern<AtenViewOp> {
size = convert;
}

// Check we are only inferring one dimension:
Value countPred =
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, count, one);
b.create<cf::AssertOp>(
loc, countPred,
b.getStringAttr("must have at most one inferred (negative) dimension"));
// Check we are only inferring one dimension if not in strict mode. In
// strict mode, there will only ever statically be one inferred dim.
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value countPred =
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, count, one);
b.create<cf::AssertOp>(
loc, countPred,
b.getStringAttr(
"must have at most one inferred (negative) dimension"));
}

// Determine the total size of the inferred dimension and update the
// inferred dimension:
@@ -1356,6 +1366,165 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern<AtenViewOp> {
};
} // namespace

namespace {
class ConvertAtenViewOpStrict : public OpConversionPattern<AtenViewOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isAssumingStrictSymbolicShapes(rewriter))
return rewriter.notifyMatchFailure(op.getLoc(),
"not strict symbolic shapes");
SmallVector<Value> sizeValues;
if (!getListConstructElements(op.getSize(), sizeValues))
return op.emitError(
"unimplemented: the tensor size list is not from list construct");

auto loc = op.getLoc();
auto resultType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
auto self = adaptor.getSelf();
auto selfTy = cast<RankedTensorType>(self.getType());

// Handle collapse to 0D.
if (sizeValues.empty()) {
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
op, resultType, adaptor.getSelf(), ArrayRef<ReassociationIndices>{});
return success();
}

// If there is a static inferred dimension (-1), then we emit a
// flatten/unflatten and let that proceed through its lowering.
// Otherwise, emit a tensor.reshape. Note that this relies on the fact that
// Torch does not allow such an op to have a symbolic inferred dim.
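// For illustration (hypothetical shapes): with an input of shape [2, 3, 4]
// and requested sizes [%n, -1] where %n is not a compile-time constant, the
// reshape path below recovers the inferred extent at runtime as
// (2 * 3 * 4) / %n; with fully static sizes such as [6, -1], the
// flatten/unflatten path is used instead.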
int inferredDim = -1;
bool staticSizes = true;
for (int i = 0, e = sizeValues.size(); i < e; ++i) {
int64_t dim;
if (!matchPattern(sizeValues[i], m_TorchConstantInt(&dim))) {
staticSizes = false;
continue;
}
if (dim == -1) {
inferredDim = i;
break;
}
}

// While it should be illegal to have a view op with fully known sizes
// and a dynamic result shape, in practice torch IR is a bit loose and
// progressively resolves to this state. The ops we produce here have
// delicate invariants that require this, so we enforce it.
if (staticSizes && !resultType.hasStaticShape()) {
return rewriter.notifyMatchFailure(loc,
"view cannot be converted with static "
"sizes and a dynamic result type");
}

// Handle inferred dim case.
// TODO: Remove the restriction on staticSizes once flatten/unflatten
// reliably work with multiple dynamic dimensions.
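// Sketch of the rewrite this branch produces (assumed shapes, pseudo-IR): a
// view of a [2, 3, 4] tensor with sizes [6, -1] first flattens dims 0..2 to
// a [24] tensor via aten.flatten.using_ints, then unflattens dim 0 with the
// requested size list via aten.unflatten.int, typed as the original view
// result ([6, 4]).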
if (inferredDim >= 0 && staticSizes) {
// This is a torch-torch conversion, so only non-adapted types are
// involved.
auto selfTy = dyn_cast<ValueTensorType>(op.getSelf().getType());
if (!selfTy || !selfTy.hasSizes())
return failure();

// Work out the 1D flattened type.
int64_t flatDim = 1;
auto selfSizes = selfTy.getSizes();
for (int64_t dim : selfSizes) {
if (dim == kUnknownSize) {
flatDim = kUnknownSize;
break;
}
flatDim *= dim;
}
// Flatten to 1D.
ValueTensorType flatType = rewriter.getType<ValueTensorType>(
ArrayRef<int64_t>{flatDim}, selfTy.getOptionalDtype());
Value dimStart = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value dimEnd = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(selfSizes.size() - 1));
Value flatSelf = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, flatType, op.getSelf(), dimStart, dimEnd);

// Unflatten to requested size.
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
op, op.getResult().getType(), flatSelf, dimStart, op.getSize());
return success();
}

// Generate output dims, either based on whether there is an inferred dim
// present or all dims are specified.
auto sizeTy = cast<IntegerType>(
typeConverter->convertType(sizeValues.front().getType()));
SmallVector<Value> outputDimValues;
assert(sizeTy && "Type converter did not handle size");
if (inferredDim >= 0) {
// Inferred dim. If the above flatten/unflatten logic ever catches
// everything, this branch can go away entirely.
Value one = rewriter.create<arith::ConstantOp>(
loc, sizeTy, rewriter.getIntegerAttr(sizeTy, 1));
Value sizeProduct = one;
// Multiply the non-inferred target sizes.
for (int i = 0, e = sizeValues.size(); i < e; ++i) {
if (i == inferredDim)
continue;
Value size = sizeValues[i];
Value convertedSize = typeConverter->materializeTargetConversion(
rewriter, loc, sizeTy, size);
assert(convertedSize && "Type converter did not handle size");
sizeProduct =
rewriter.create<arith::MulIOp>(loc, sizeProduct, convertedSize);
}

// Multiply the self tensor sizes.
Value selfProduct = one;
for (int i = 0, e = selfTy.getRank(); i < e; ++i) {
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
Value dim = rewriter.create<tensor::DimOp>(loc, self, index);
dim = rewriter.create<arith::IndexCastOp>(loc, sizeTy, dim);
selfProduct = rewriter.create<arith::MulIOp>(loc, selfProduct, dim);
}

Value inferredSize =
rewriter.create<arith::DivUIOp>(loc, selfProduct, sizeProduct);
for (int i = 0, e = sizeValues.size(); i < e; ++i) {
if (i == inferredDim) {
outputDimValues.push_back(inferredSize);
} else {
outputDimValues.push_back(typeConverter->materializeTargetConversion(
rewriter, loc, sizeTy, sizeValues[i]));
}
}
} else {
// No inferred dim, so the output dims are just passed through.
for (Value torchSize : sizeValues) {
outputDimValues.push_back(typeConverter->materializeTargetConversion(
rewriter, loc, sizeTy, torchSize));
}
}

// Normal lowering to reshape with fully computed sizes.
auto outputDimsTy = RankedTensorType::get(
outputDimValues.size(), outputDimValues.front().getType());
auto outputDims = rewriter.create<tensor::FromElementsOp>(loc, outputDimsTy,
outputDimValues);
rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(
op, resultType, adaptor.getSelf(), outputDims);
return success();
}
};
} // namespace

namespace {
class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
public:
@@ -2419,6 +2588,9 @@ SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
// Add some legal ops for torch-torch lowering.
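// In particular, ConvertAtenViewOpStrict above materializes
// Torch::ConstantIntOp operands for its flatten/unflatten rewrite, so that
// op must stay legal for the conversion to succeed.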
target.addLegalOp<ConstantIntOp>();

MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenReflectionPad1dOp>();
patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
@@ -2428,10 +2600,23 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
patterns.add<ConvertAtenUnflattenIntOp>(typeConverter, context);
target.addIllegalOp<AtenUnflattenIntOp>();

// View op sadness: In the future, we only want ConvertAtenViewOpStrict,
// but this requires work upstream to fully generalize reshape handling.
// In the meantime, the analysis-based ConvertAtenViewOp tries hard to
// produce expand/collapse shapes, ConvertAtenViewOpStrict does the right
// thing but cannot yet fully support dynamic shapes, and
// ConvertAtenViewOpToReshape is overly pessimistic and generates a lot of
// IR because it does not statically switch between the inferred and
// non-inferred view cases. The patterns are ordered by the optimality of
// the lowerings they generate when they apply.
target.addIllegalOp<AtenViewOp>();
patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/200);
patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/300);
patterns.add<ConvertAtenViewOpStrict>(typeConverter, context,
/*benefit=*/200);
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
/*benefit=*/100);

target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeDimOp>();
5 changes: 5 additions & 0 deletions python/torch_mlir/extras/fx_importer.py
@@ -103,6 +103,7 @@
StringAttr,
SymbolTable,
Type as IrType,
UnitAttr,
Value,
)

@@ -642,6 +643,10 @@ def import_program(
func_op = func_dialect.FuncOp(
func_name, ftype, ip=self._m_ip, visibility=func_visibility
)
# Programs imported from FX have strong guarantees. Setting this attribute
# allows various lowerings to emit more efficient code or to handle more
# cases. See isAssumingStrictSymbolicShapes().
func_op.attributes["torch.assume_strict_symbolic_shapes"] = UnitAttr.get()
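# For illustration (assumed printed form), the resulting function header then
# carries the unit attribute, roughly:
#   func.func @main(...) -> ... attributes {torch.assume_strict_symbolic_shapes}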
entry_block = Block.create_at_start(func_op.body, ftype.inputs)

node_importer = GraphNodeImporter(