[TorchToLinalg] Support lowering torch.isclose to linalg (llvm#3631)
lingzhiz1998 authored Aug 21, 2024 · 1 parent a24114e · commit 7f886cc
Showing 2 changed files with 44 additions and 4 deletions.
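For context, torch.isclose compares two tensors elementwise and returns a boolean tensor: an element is close when ∣input − other∣ <= atol + rtol × ∣other∣, with defaults rtol=1e-05, atol=1e-08, and equal_nan=False. A short illustration of the semantics this commit lowers (the values below are chosen purely for illustration):

import torch

a = torch.tensor([1.0, 2.0, float("nan")])
b = torch.tensor([1.0 + 1e-6, 2.5, float("nan")])
# |1.0 - 1.000001| <= 1e-8 + 1e-5 * 1.000001  -> True
# |2.0 - 2.5|      >  1e-8 + 1e-5 * 2.5       -> False
# NaN is never close to anything when equal_nan=False
print(torch.isclose(a, b))  # tensor([ True, False, False])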
46 changes: 44 additions & 2 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1506,6 +1506,48 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return value;
   }
 
+  if (auto isClose = dyn_cast<AtenIscloseOp>(op)) {
+    double rtol, atol;
+    bool equalNan;
+    if (!matchPattern(isClose.getRtol(), m_TorchConstantFloat(&rtol))) {
+      isClose.emitError("rtol must be a scalar constant");
+      return nullptr;
+    }
+    if (!matchPattern(isClose.getAtol(), m_TorchConstantFloat(&atol))) {
+      isClose.emitError("atol must be a scalar constant");
+      return nullptr;
+    }
+    if (!matchPattern(isClose.getEqualNan(), m_TorchConstantBool(&equalNan))) {
+      isClose.emitError("equal_nan must be a scalar constant");
+      return nullptr;
+    }
+    auto lhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[0].getType());
+    auto rhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[1].getType());
+    if (!lhsType || !rhsType) {
+      isClose.emitError("unimplemented: only FP element type is supported");
+      return nullptr;
+    }
+    // Choose the wider float type as the compute type (at least f32).
+    auto computeType =
+        lhsType.getWidth() > rhsType.getWidth() ? lhsType : rhsType;
+    computeType = computeType.getWidth() >= 32 ? computeType : b.getF32Type();
+    auto cvtArg0 = convertScalarToDtype(b, loc, payloadArgs[0], computeType);
+    auto cvtArg1 = convertScalarToDtype(b, loc, payloadArgs[1], computeType);
+    // Per the definition of torch.isclose:
+    //   ∣input − other∣ <= atol + rtol × ∣other∣
+    auto diff = b.create<arith::SubFOp>(loc, computeType, cvtArg0, cvtArg1);
+    auto absDiff = b.create<math::AbsFOp>(loc, computeType, diff);
+    auto cstRtol =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, rtol));
+    auto absOther = b.create<math::AbsFOp>(loc, computeType, cvtArg1);
+    auto mul = b.create<arith::MulFOp>(loc, computeType, cstRtol, absOther);
+    auto cstAtol =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, atol));
+    auto threshold = b.create<arith::AddFOp>(loc, computeType, cstAtol, mul);
+    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, absDiff,
+                                   threshold);
+  }
+
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
@@ -1564,7 +1606,7 @@ class ConvertElementwiseOp : public ConversionPattern {
              AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
              AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
              AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
-             AtenQuantizePerTensorOp>(op))
+             AtenQuantizePerTensorOp, AtenIscloseOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3256,7 +3298,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
       AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
       AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
-      AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
+      AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
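The payload above evaluates ∣lhs − rhs∣ <= atol + rtol × ∣rhs∣ per element, computing in the wider of the two operand float types, widened to at least f32. Below is a minimal NumPy sketch of that scalar logic; isclose_payload is an illustrative name, not part of the patch. Note one deliberate difference: the lowering emits the ULE (unordered-or-<=) predicate, which yields true when either operand is NaN, whereas NumPy's ordered comparison yields false.

import numpy as np

def isclose_payload(lhs, rhs, rtol=1e-5, atol=1e-8):
    # Illustrative sketch: pick the wider operand dtype as the compute type,
    # with an f32 floor, mirroring the computeType selection in the payload.
    lhs, rhs = np.asarray(lhs), np.asarray(rhs)
    compute = lhs.dtype if lhs.dtype.itemsize >= rhs.dtype.itemsize else rhs.dtype
    if compute.itemsize < 4:
        compute = np.dtype(np.float32)
    lhs, rhs = lhs.astype(compute), rhs.astype(compute)
    # |lhs - rhs| <= atol + rtol * |rhs|
    return np.abs(lhs - rhs) <= atol + rtol * np.abs(rhs)

print(isclose_payload(np.float32(1.0), np.float32(1.0 + 1e-6)))  # True
print(isclose_payload(np.float16(1.0), np.float16(1.5)))         # False (computed in f32)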
2 changes: 0 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -16,8 +16,6 @@
 print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())
 
 LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
-    "IscloseStaticModule_basic",
-    "IscloseStaticModuleTrue_basic",
     # lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
     # these interpolate tests are added specifically to test onnx.Resize.
     "InterpolateDynamicModule_sizes_bilinear",
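Dropping these two entries re-enables the isclose e2e tests for the linalg backend now that the lowering exists. For readers unfamiliar with the suite, here is a sketch of what such a test looks like in torch-mlir's e2e style, assuming the standard torch_mlir_e2e_test helpers; the module name and shapes are illustrative, not copied from the actual IscloseStaticModule tests:

import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

class IscloseExampleModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([5, 5], torch.float32, True),
        ([5, 5], torch.float32, True),
    ])
    def forward(self, x, y):
        # Exercises the new TorchToLinalg lowering of aten.isclose.
        return torch.isclose(x, y)

@register_test_case(module_factory=lambda: IscloseExampleModule())
def IscloseExampleModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 5), tu.rand(5, 5))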
