Skip to content

Commit

Permalink
Merge commit 'ba9f1840' into matthias.bump_torch_mlir5
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Apr 17, 2024
2 parents 30405e7 + ba9f184 commit 9128b04
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 52 deletions.
49 changes: 0 additions & 49 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5757,43 +5757,6 @@ LogicalResult ConvertAtenOp<AtenRepeatInterleaveTensorOp>::matchAndRewrite(
return success();
}

template <typename AtenOpT>
class ConvertAtenOpToTosaCustomOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

ConvertAtenOpToTosaCustomOp(TypeConverter &typeConverter,
MLIRContext *context, std::string opName,
std::string implementedWithOpAttr = "UNDEF")
: OpConversionPattern<AtenOpT>(typeConverter, context),
opName(std::move(opName)),
implementedWithOpAttr(std::move(implementedWithOpAttr)) {}

LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Set tosa.custom_op attributes.
// Only identifier needs to be known. Other attributes are not used.
auto *ctx = op->getContext();
auto identifier = StringAttr::get(ctx, opName);
auto implementAttr = StringAttr::get(ctx, implementedWithOpAttr);
auto config = StringAttr::get(ctx, "UNDEF");

rewriter.replaceOpWithNewOp<tosa::CustomOp>(
op,
TypeRange{OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType())},
identifier, config, implementAttr, adaptor.getOperands());
return success();
}

private:
std::string opName;
std::string implementedWithOpAttr;
};

class SimplifyAtenIndexTensorWithSliceIndex
: public OpRewritePattern<AtenIndexTensorOp> {
public:
Expand Down Expand Up @@ -6232,18 +6195,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
#undef INSERT_CLONE_ATENOP_PATTERN

#define INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenOp, opName, implementedWith) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOpToTosaCustomOp<AtenOp>>(typeConverter, context, \
opName, implementedWith);
INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenAtan2Op, "math.atan2",
"linalg.generic");
INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenSinOp, "math.sin",
"linalg.generic");
INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenCosOp, "math.cos",
"linalg.generic");
#undef INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
Expand Down
3 changes: 0 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,6 @@
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseAsinModule_basic",
"ElementwiseAsinTensorFloatModule_basic",
"ElementwiseAtan2TensorFloatModule_basic",
"ElementwiseAtenDivIntScalarModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
Expand Down Expand Up @@ -1082,7 +1081,6 @@
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseDivScalarModule_basic",
"ElementwiseDivTensorIntegerModule_basic",
"ElementwiseDivTensorUnsignedIntegerModule_basic",
Expand Down Expand Up @@ -1169,7 +1167,6 @@
"ElementwiseSeluModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSignModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseSqrtIntModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseSubScalarFloatModule_basic",
Expand Down

0 comments on commit 9128b04

Please sign in to comment.