Skip to content
This repository has been archived by the owner on Dec 12, 2024. It is now read-only.

Commit

Permalink
Replace {i,getI}64ElementsAttr with rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
marbre committed Jan 11, 2024
1 parent 7d568bd commit 62f78e5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 52 deletions.
57 changes: 19 additions & 38 deletions lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,6 @@ using namespace mlir::emitc;
namespace {

/// Common functions.
/// Adopted from mlir-hlo.
DenseIntElementsAttr i64ElementsAttr(int64_t value, size_t count,
MLIRContext *ctx) {
RankedTensorType ty = RankedTensorType::get({static_cast<int64_t>(count)},
IntegerType::get(ctx, 64));
SmallVector<int64_t, 4> values(count, value);
return DenseIntElementsAttr::get(ty, values);
}

DenseIntElementsAttr getI64ElementsAttr(const ArrayRef<long> values,
MLIRContext *ctx) {
RankedTensorType ty = RankedTensorType::get(
{static_cast<int64_t>(values.size())}, IntegerType::get(ctx, 64));

return DenseIntElementsAttr::get(ty, values);
}

SmallVector<Attribute, 2> indexSequence(int64_t n, MLIRContext *ctx) {
return llvm::to_vector<2>(
llvm::map_range(llvm::seq<int64_t>(0, n), [&ctx](int64_t i) -> Attribute {
Expand Down Expand Up @@ -172,7 +155,6 @@ class ConvOpConversion : public OpConversionPattern<stablehlo::ConvolutionOp> {
LogicalResult
matchAndRewrite(stablehlo::ConvolutionOp convOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *ctx = convOp.getContext();

StringRef funcName = "emitc::stablehlo::convolution";
StringAttr callee = rewriter.getStringAttr(funcName);
Expand Down Expand Up @@ -201,14 +183,14 @@ class ConvOpConversion : public OpConversionPattern<stablehlo::ConvolutionOp> {
convOp.getDimensionNumbers().getOutputSpatialDimensions()));
arguments.push_back(convOp.getFeatureGroupCountAttr());

arguments.push_back(
convOp.getPadding().value_or(i64ElementsAttr(0, 2, ctx)));
arguments.push_back(
convOp.getLhsDilation().value_or(i64ElementsAttr(1, 2, ctx)));
arguments.push_back(
convOp.getRhsDilation().value_or(i64ElementsAttr(1, 2, ctx)));
arguments.push_back(
convOp.getWindowStrides().value_or(i64ElementsAttr(1, 2, ctx)));
arguments.push_back(convOp.getPadding().value_or(
rewriter.getI64TensorAttr(SmallVector<int64_t>(0, 2))));
arguments.push_back(convOp.getLhsDilation().value_or(
rewriter.getI64TensorAttr(SmallVector<int64_t>(1, 2))));
arguments.push_back(convOp.getRhsDilation().value_or(
rewriter.getI64TensorAttr(SmallVector<int64_t>(1, 2))));
arguments.push_back(convOp.getWindowStrides().value_or(
rewriter.getI64TensorAttr(SmallVector<int64_t>(1, 2))));

ArrayAttr args = rewriter.getArrayAttr(arguments);
ArrayAttr templateArgs =
Expand Down Expand Up @@ -318,12 +300,11 @@ class SliceOpConversion : public OpConversionPattern<stablehlo::SliceOp> {
SmallVector<Attribute, 2> arguments =
indexSequence(adaptor.getOperands().size(), sliceOp.getContext());

arguments.push_back(getI64ElementsAttr(sliceOp.getStartIndicesAttr(),
sliceOp.getContext()));
arguments.push_back(getI64ElementsAttr(sliceOp.getLimitIndicesAttr(),
sliceOp.getContext()));
arguments.push_back(
getI64ElementsAttr(sliceOp.getStridesAttr(), sliceOp.getContext()));
rewriter.getI64TensorAttr(sliceOp.getStartIndicesAttr()));
arguments.push_back(
rewriter.getI64TensorAttr(sliceOp.getLimitIndicesAttr()));
arguments.push_back(rewriter.getI64TensorAttr(sliceOp.getStridesAttr()));

ArrayAttr args = rewriter.getArrayAttr(arguments);

Expand Down Expand Up @@ -357,8 +338,8 @@ class DynamicSliceOpConversion
SmallVector<Attribute, 2> arguments = indexSequence(
adaptor.getOperands().size(), dynamicSliceOp.getContext());

arguments.push_back(getI64ElementsAttr(dynamicSliceOp.getSliceSizesAttr(),
dynamicSliceOp.getContext()));
arguments.push_back(
rewriter.getI64TensorAttr(dynamicSliceOp.getSliceSizesAttr()));

ArrayAttr args = rewriter.getArrayAttr(arguments);

Expand Down Expand Up @@ -423,11 +404,11 @@ class PadOpConversion : public OpConversionPattern<stablehlo::PadOp> {
indexSequence(adaptor.getOperands().size(), padOp.getContext());

arguments.push_back(
getI64ElementsAttr(padOp.getEdgePaddingLowAttr(), padOp.getContext()));
rewriter.getI64TensorAttr(padOp.getEdgePaddingLowAttr()));
arguments.push_back(
getI64ElementsAttr(padOp.getEdgePaddingHighAttr(), padOp.getContext()));
rewriter.getI64TensorAttr(padOp.getEdgePaddingHighAttr()));
arguments.push_back(
getI64ElementsAttr(padOp.getInteriorPaddingAttr(), padOp.getContext()));
rewriter.getI64TensorAttr(padOp.getInteriorPaddingAttr()));

ArrayAttr args = rewriter.getArrayAttr(arguments);

Expand Down Expand Up @@ -460,8 +441,8 @@ class TransposeOpConversion
SmallVector<Attribute> arguments =
indexSequence(adaptor.getOperands().size(), transposeOp.getContext());

arguments.push_back(getI64ElementsAttr(transposeOp.getPermutationAttr(),
transposeOp.getContext()));
arguments.push_back(
rewriter.getI64TensorAttr(transposeOp.getPermutationAttr()));
ArrayAttr args = rewriter.getArrayAttr(arguments);

Type resultType = transposeOp.getResult().getType();
Expand Down
20 changes: 6 additions & 14 deletions lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ using namespace mlir::emitc;
namespace {

/// Common functions.
DenseIntElementsAttr getI64ElementsAttr(const ArrayRef<long> values,
MLIRContext *ctx) {
RankedTensorType ty = RankedTensorType::get(
{static_cast<int64_t>(values.size())}, IntegerType::get(ctx, 64));

return DenseIntElementsAttr::get(ty, values);
}

SmallVector<Attribute, 2> indexSequence(int64_t n, MLIRContext *ctx) {
return llvm::to_vector<2>(
llvm::map_range(llvm::seq<int64_t>(0, n), [&ctx](int64_t i) -> Attribute {
Expand Down Expand Up @@ -115,9 +107,9 @@ class GenericConvOpConversion : public OpConversionPattern<SrcOp> {
ArrayAttr args = rewriter.getArrayAttr({
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(1),
getI64ElementsAttr(convOp.getPad(), convOp.getContext()),
getI64ElementsAttr(convOp.getStride(), convOp.getContext()),
getI64ElementsAttr(convOp.getDilation(), convOp.getContext()),
rewriter.getI64TensorAttr(convOp.getPad()),
rewriter.getI64TensorAttr(convOp.getStride()),
rewriter.getI64TensorAttr(convOp.getDilation()),
});
// clang-format on

Expand Down Expand Up @@ -161,9 +153,9 @@ class GenericPoolOpConversion : public OpConversionPattern<SrcOp> {
// clang-format off
ArrayAttr args = rewriter.getArrayAttr({
rewriter.getIndexAttr(0),
getI64ElementsAttr(poolOp.getPad(), poolOp.getContext()),
getI64ElementsAttr(poolOp.getStride(), poolOp.getContext()),
getI64ElementsAttr(poolOp.getKernel(), poolOp.getContext()),
rewriter.getI64TensorAttr(poolOp.getPad()),
rewriter.getI64TensorAttr(poolOp.getStride()),
rewriter.getI64TensorAttr(poolOp.getKernel()),
});
// clang-format on

Expand Down

0 comments on commit 62f78e5

Please sign in to comment.