Skip to content

Commit

Permalink
Merge pull request #155 from Xilinx/tiagot.remove_tosa_custom_op_lega…
Browse files Browse the repository at this point in the history
…lizations

feat: remove TorchToTosa legalizations that lower to tosa.custom.
  • Loading branch information
ttjost authored Mar 25, 2024
2 parents 2741640 + de7be82 commit ba9f184
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 55 deletions.
49 changes: 0 additions & 49 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5686,43 +5686,6 @@ LogicalResult ConvertAtenOp<AtenRepeatInterleaveTensorOp>::matchAndRewrite(
return success();
}

template <typename AtenOpT>
class ConvertAtenOpToTosaCustomOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

ConvertAtenOpToTosaCustomOp(TypeConverter &typeConverter,
MLIRContext *context, std::string opName,
std::string implementedWithOpAttr = "UNDEF")
: OpConversionPattern<AtenOpT>(typeConverter, context),
opName(std::move(opName)),
implementedWithOpAttr(std::move(implementedWithOpAttr)) {}

LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Set tosa.custom_op attributes.
// Only identifier needs to be known. Other attributes are not used.
auto *ctx = op->getContext();
auto identifier = StringAttr::get(ctx, opName);
auto implementAttr = StringAttr::get(ctx, implementedWithOpAttr);
auto config = StringAttr::get(ctx, "UNDEF");

rewriter.replaceOpWithNewOp<tosa::CustomOp>(
op,
TypeRange{OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType())},
identifier, config, implementAttr, adaptor.getOperands());
return success();
}

private:
std::string opName;
std::string implementedWithOpAttr;
};

class SimplifyAtenIndexTensorWithSliceIndex
: public OpRewritePattern<AtenIndexTensorOp> {
public:
Expand Down Expand Up @@ -6159,18 +6122,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
#undef INSERT_CLONE_ATENOP_PATTERN

#define INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenOp, opName, implementedWith) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOpToTosaCustomOp<AtenOp>>(typeConverter, context, \
opName, implementedWith);
INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenAtan2Op, "math.atan2",
"linalg.generic");
INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenSinOp, "math.sin",
"linalg.generic");
INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenCosOp, "math.cos",
"linalg.generic");
#undef INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
Expand Down
6 changes: 0 additions & 6 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,15 +1102,11 @@
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"ElementwiseAbsModule_basic",
"ElementwiseAcosModule_basic",
"ElementwiseAcosTensorFloatModule_basic",
"ElementwiseAddModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseAsinTensorFloatModule_basic",
"ElementwiseAtan2TensorFloatModule_basic",
"ElementwiseAtenDivIntScalarModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
Expand Down Expand Up @@ -1147,7 +1143,6 @@
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseDivScalarModule_basic",
"ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic",
Expand Down Expand Up @@ -1227,7 +1222,6 @@
"ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSignModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseSqrtIntModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseSubScalarFloatModule_basic",
Expand Down

0 comments on commit ba9f184

Please sign in to comment.