Add aten.gt.Tensor op
The `aten.gt.Tensor` op has been added to the Torch dialect, and its
lowering to the Linalg dialect has been implemented.

Signed-off-by: Prashant Kumar <[email protected]>
Prashant Kumar committed Dec 12, 2021
1 parent a778f99 commit 528354d
Showing 3 changed files with 83 additions and 2 deletions.
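
As background for the diffs below: `torch.gt` is an elementwise greater-than that follows the usual broadcasting rules and returns a boolean tensor. A minimal plain-PyTorch illustration (not part of the commit), mirroring the shapes used by the new tests:

import torch

x = torch.rand(3, 5)          # shape (3, 5), float32
y = torch.rand(5)             # shape (5,), broadcasts along the last dimension
out = torch.gt(x, y)          # same as `x > y`
print(out.shape, out.dtype)   # torch.Size([3, 5]) torch.bool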
36 changes: 36 additions & 0 deletions e2e_testing/torchscript/elementwise.py
@@ -343,6 +343,42 @@ def forward(self, x):
def ElementwiseGtScalarModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))

class ElementwiseGtFloatTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.gt(x, y)


@register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule())
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5), tu.rand(5))

class ElementwiseGtIntTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, y):
        return torch.gt(x, y)


@register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule())
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5,)))

# ==============================================================================


30 changes: 28 additions & 2 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1669,6 +1669,32 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
      return b.create<arith::MulIOp>(loc, lhs, rhs);
    }
  }
  if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
    AtenGtTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();

    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
    // to be handled.
    if (lhsDtype != rhsDtype)
      gtTensor.emitError("unimplemented: different lhs and rhs dtype");

    Type elementalType =
        gtTensor.self().getType().cast<BaseTensorType>().getDtype();

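    // Note: arith::CmpFPredicate::UGT is an unordered greater-than, so the
    // comparison also yields true when either operand is NaN.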
    if (elementalType.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                     payloadArgs[0], payloadArgs[1]);
    if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) {
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                       payloadArgs[0], payloadArgs[1]);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                       payloadArgs[0], payloadArgs[1]);
    }
    gtTensor.emitError("unimplemented: dtype isn't supported.");
  }
  if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
    AtenDivTensorOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(div.getType())
@@ -2070,7 +2096,7 @@ struct ConvertElementwiseOp : ConversionPattern {
             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
             AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
             AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp,
             AtenCeilOp>(op))
             AtenCeilOp, AtenGtTensorOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3640,7 +3666,7 @@ class ConvertTorchToLinalg
        AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
        AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
        AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
        AtenWhereSelfOp>();
        AtenWhereSelfOp, AtenGtTensorOp>();
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenSqueezeOp>();
    patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
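
For readers less familiar with the arith comparison predicates, the branch added above can be modelled in plain PyTorch as follows. This is only an illustrative sketch, not code from the commit; the helper name is made up, and unsigned integers are not exercised by the new tests, which use torch.int64:

import torch

def gt_payload_reference(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    # Illustrative model of the per-element computation selected above for
    # aten.gt.Tensor; not code from this commit.
    if lhs.is_floating_point():
        # arith::CmpFPredicate::UGT is unordered: NaN in either operand -> true.
        return torch.isnan(lhs) | torch.isnan(rhs) | (lhs > rhs)
    # Integer tensors map to CmpIPredicate::sgt or ::ugt depending on
    # signedness; PyTorch's '>' on torch.int64 matches the signed case.
    return lhs > rhs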
19 changes: 19 additions & 0 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -320,6 +320,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
                   AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
                   AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
      return visitBinaryBroadcastingOp(op, operands);
    } else if (isa<AtenGtTensorOp>(op)) {
      return visitBinaryBroadcastingComparisonOp(op, operands);
    } else if (auto whereSelf = llvm::dyn_cast<AtenWhereSelfOp>(op)) {
      return visitAtenWhereSelfOp(whereSelf, operands);
    } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
@@ -505,6 +507,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitBinaryBroadcastingOp(
      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitBinaryBroadcastingComparisonOp(
      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenWhereSelfOp(AtenWhereSelfOp op,
                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -884,6 +888,21 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitBinaryBroadcastingComparisonOp(
    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto lhs = operands[0]->getValue();
  auto rhs = operands[1]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(getContext());
  if (lhs.hasSizes && rhs.hasSizes) {
    knowledge.hasSizes = true;
    knowledge.sizes.resize(std::max(lhs.sizes.size(), rhs.sizes.size()),
                           kUnknownSize);
  }
  knowledge.dtype = IntegerType::get(op->getContext(), 1);
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenWhereSelfOp(
    AtenWhereSelfOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto condition = operands[0]->getValue();
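
The rule implemented by visitBinaryBroadcastingComparisonOp differs from the ordinary binary-broadcasting visitor only in the result dtype: comparison ops always produce an i1 (boolean) tensor whose rank is the larger of the two operand ranks, with every dimension left unknown. A hedged Python sketch of that rule follows; the helper name and the K_UNKNOWN_SIZE stand-in are illustrative, not from the codebase:

from typing import List, Optional, Tuple

K_UNKNOWN_SIZE = -1  # stand-in for torch-mlir's kUnknownSize marker

def refine_comparison_result(
        lhs_sizes: Optional[List[int]],
        rhs_sizes: Optional[List[int]]) -> Tuple[Optional[List[int]], str]:
    # Illustrative model of visitBinaryBroadcastingComparisonOp above.
    sizes = None
    if lhs_sizes is not None and rhs_sizes is not None:
        # Only the rank is refined: the broadcasted result has the larger of
        # the two ranks, and every dimension is left unknown.
        sizes = [K_UNKNOWN_SIZE] * max(len(lhs_sizes), len(rhs_sizes))
    return sizes, "i1"  # comparisons always yield a boolean (i1) tensor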
