diff --git a/lib/Conversion/StablehloToEmitC/StablehloRegionOpsToEmitC.cpp b/lib/Conversion/StablehloToEmitC/StablehloRegionOpsToEmitC.cpp index cb78e22f..ff396442 100644 --- a/lib/Conversion/StablehloToEmitC/StablehloRegionOpsToEmitC.cpp +++ b/lib/Conversion/StablehloToEmitC/StablehloRegionOpsToEmitC.cpp @@ -27,15 +27,6 @@ using namespace mlir::emitc; namespace { /// Common functions. -/// Adopted from mlir-hlo. -DenseIntElementsAttr i64ElementsAttr(int64_t value, size_t count, - MLIRContext *ctx) { - RankedTensorType ty = RankedTensorType::get({static_cast(count)}, - IntegerType::get(ctx, 64)); - SmallVector values(count, value); - return DenseIntElementsAttr::get(ty, values); -} - SmallVector indexSequence(int64_t n, MLIRContext *ctx) { return llvm::to_vector<2>( llvm::map_range(llvm::seq(0, n), [&ctx](int64_t i) -> Attribute { @@ -202,14 +193,14 @@ struct ConvertStablehloRegionOpsToEmitCPass size_t dim = op.getResult(0).getType().cast().getRank(); arguments.push_back(op.getWindowDimensions()); - arguments.push_back( - op.getWindowStrides().value_or(i64ElementsAttr(1, dim, ctx))); - arguments.push_back( - op.getBaseDilations().value_or(i64ElementsAttr(1, dim, ctx))); - arguments.push_back( - op.getBaseDilations().value_or(i64ElementsAttr(1, dim, ctx))); - arguments.push_back( - op.getPadding().value_or(i64ElementsAttr(0, 2 * dim, ctx))); + arguments.push_back(op.getWindowStrides().value_or( + builder.getI64TensorAttr(SmallVector(dim, 1)))); + arguments.push_back(op.getBaseDilations().value_or( + builder.getI64TensorAttr(SmallVector(dim, 1)))); + arguments.push_back(op.getBaseDilations().value_or( + builder.getI64TensorAttr(SmallVector(dim, 1)))); + arguments.push_back(op.getPadding().value_or( + builder.getI64TensorAttr(SmallVector(2 * dim, 0)))); arguments.push_back(SymbolRefAttr::get(ctx, funcOp.getName())); ArrayAttr args = ArrayAttr::get(ctx, arguments); diff --git a/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp b/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp index 4f2d3d38..549f2895 100644 --- a/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp +++ b/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp @@ -27,23 +27,6 @@ using namespace mlir::emitc; namespace { /// Common functions. -/// Adopted from mlir-hlo. -DenseIntElementsAttr i64ElementsAttr(int64_t value, size_t count, - MLIRContext *ctx) { - RankedTensorType ty = RankedTensorType::get({static_cast(count)}, - IntegerType::get(ctx, 64)); - SmallVector values(count, value); - return DenseIntElementsAttr::get(ty, values); -} - -DenseIntElementsAttr getI64ElementsAttr(const ArrayRef values, - MLIRContext *ctx) { - RankedTensorType ty = RankedTensorType::get( - {static_cast(values.size())}, IntegerType::get(ctx, 64)); - - return DenseIntElementsAttr::get(ty, values); -} - SmallVector indexSequence(int64_t n, MLIRContext *ctx) { return llvm::to_vector<2>( llvm::map_range(llvm::seq(0, n), [&ctx](int64_t i) -> Attribute { @@ -172,7 +155,6 @@ class ConvOpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(stablehlo::ConvolutionOp convOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto *ctx = convOp.getContext(); StringRef funcName = "emitc::stablehlo::convolution"; StringAttr callee = rewriter.getStringAttr(funcName); @@ -201,14 +183,14 @@ class ConvOpConversion : public OpConversionPattern { convOp.getDimensionNumbers().getOutputSpatialDimensions())); arguments.push_back(convOp.getFeatureGroupCountAttr()); - arguments.push_back( - convOp.getPadding().value_or(i64ElementsAttr(0, 2, ctx))); - arguments.push_back( - convOp.getLhsDilation().value_or(i64ElementsAttr(1, 2, ctx))); - arguments.push_back( - convOp.getRhsDilation().value_or(i64ElementsAttr(1, 2, ctx))); - arguments.push_back( - convOp.getWindowStrides().value_or(i64ElementsAttr(1, 2, ctx))); + arguments.push_back(convOp.getPadding().value_or( + rewriter.getI64TensorAttr(SmallVector(2, 0)))); + arguments.push_back(convOp.getLhsDilation().value_or( + rewriter.getI64TensorAttr(SmallVector(2, 1)))); + arguments.push_back(convOp.getRhsDilation().value_or( + rewriter.getI64TensorAttr(SmallVector(2, 1)))); + arguments.push_back(convOp.getWindowStrides().value_or( + rewriter.getI64TensorAttr(SmallVector(2, 1)))); ArrayAttr args = rewriter.getArrayAttr(arguments); ArrayAttr templateArgs = @@ -318,12 +300,11 @@ class SliceOpConversion : public OpConversionPattern { SmallVector arguments = indexSequence(adaptor.getOperands().size(), sliceOp.getContext()); - arguments.push_back(getI64ElementsAttr(sliceOp.getStartIndicesAttr(), - sliceOp.getContext())); - arguments.push_back(getI64ElementsAttr(sliceOp.getLimitIndicesAttr(), - sliceOp.getContext())); arguments.push_back( - getI64ElementsAttr(sliceOp.getStridesAttr(), sliceOp.getContext())); + rewriter.getI64TensorAttr(sliceOp.getStartIndicesAttr())); + arguments.push_back( + rewriter.getI64TensorAttr(sliceOp.getLimitIndicesAttr())); + arguments.push_back(rewriter.getI64TensorAttr(sliceOp.getStridesAttr())); ArrayAttr args = rewriter.getArrayAttr(arguments); @@ -357,8 +338,8 @@ class DynamicSliceOpConversion SmallVector arguments = indexSequence( adaptor.getOperands().size(), dynamicSliceOp.getContext()); - arguments.push_back(getI64ElementsAttr(dynamicSliceOp.getSliceSizesAttr(), - dynamicSliceOp.getContext())); + arguments.push_back( + rewriter.getI64TensorAttr(dynamicSliceOp.getSliceSizesAttr())); ArrayAttr args = rewriter.getArrayAttr(arguments); @@ -423,11 +404,11 @@ class PadOpConversion : public OpConversionPattern { indexSequence(adaptor.getOperands().size(), padOp.getContext()); arguments.push_back( - getI64ElementsAttr(padOp.getEdgePaddingLowAttr(), padOp.getContext())); + rewriter.getI64TensorAttr(padOp.getEdgePaddingLowAttr())); arguments.push_back( - getI64ElementsAttr(padOp.getEdgePaddingHighAttr(), padOp.getContext())); + rewriter.getI64TensorAttr(padOp.getEdgePaddingHighAttr())); arguments.push_back( - getI64ElementsAttr(padOp.getInteriorPaddingAttr(), padOp.getContext())); + rewriter.getI64TensorAttr(padOp.getInteriorPaddingAttr())); ArrayAttr args = rewriter.getArrayAttr(arguments); @@ -460,8 +441,8 @@ class TransposeOpConversion SmallVector arguments = indexSequence(adaptor.getOperands().size(), transposeOp.getContext()); - arguments.push_back(getI64ElementsAttr(transposeOp.getPermutationAttr(), - transposeOp.getContext())); + arguments.push_back( + rewriter.getI64TensorAttr(transposeOp.getPermutationAttr())); ArrayAttr args = rewriter.getArrayAttr(arguments); Type resultType = transposeOp.getResult().getType(); diff --git a/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp b/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp index 600a74ba..a5fc5e0d 100644 --- a/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp +++ b/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp @@ -27,14 +27,6 @@ using namespace mlir::emitc; namespace { /// Common functions. -DenseIntElementsAttr getI64ElementsAttr(const ArrayRef values, - MLIRContext *ctx) { - RankedTensorType ty = RankedTensorType::get( - {static_cast(values.size())}, IntegerType::get(ctx, 64)); - - return DenseIntElementsAttr::get(ty, values); -} - SmallVector indexSequence(int64_t n, MLIRContext *ctx) { return llvm::to_vector<2>( llvm::map_range(llvm::seq(0, n), [&ctx](int64_t i) -> Attribute { @@ -115,9 +107,9 @@ class GenericConvOpConversion : public OpConversionPattern { ArrayAttr args = rewriter.getArrayAttr({ rewriter.getIndexAttr(0), rewriter.getIndexAttr(1), - getI64ElementsAttr(convOp.getPad(), convOp.getContext()), - getI64ElementsAttr(convOp.getStride(), convOp.getContext()), - getI64ElementsAttr(convOp.getDilation(), convOp.getContext()), + rewriter.getI64TensorAttr(convOp.getPad()), + rewriter.getI64TensorAttr(convOp.getStride()), + rewriter.getI64TensorAttr(convOp.getDilation()), }); // clang-format on @@ -161,9 +153,9 @@ class GenericPoolOpConversion : public OpConversionPattern { // clang-format off ArrayAttr args = rewriter.getArrayAttr({ rewriter.getIndexAttr(0), - getI64ElementsAttr(poolOp.getPad(), poolOp.getContext()), - getI64ElementsAttr(poolOp.getStride(), poolOp.getContext()), - getI64ElementsAttr(poolOp.getKernel(), poolOp.getContext()), + rewriter.getI64TensorAttr(poolOp.getPad()), + rewriter.getI64TensorAttr(poolOp.getStride()), + rewriter.getI64TensorAttr(poolOp.getKernel()), }); // clang-format on