Skip to content
This repository has been archived by the owner on Dec 12, 2024. It is now read-only.

Commit

Permalink
Integrate LLVM at llvm/llvm-project@511ba45 (#395)
Browse files Browse the repository at this point in the history
Updates LLVM usage to match llvm/llvm-project@511ba45. Further updates
the StableHLO submodule to openxla/stablehlo@42387b0.
  • Loading branch information
marbre authored Dec 11, 2023
1 parent ec795ef commit 0fa488a
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 272 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
6ae7b735dbd50eb7ade1573a86d037a2943e679c
511ba45a47d6f9e48ad364181830c9fb974135b2
8 changes: 4 additions & 4 deletions include/emitc/Conversion/EmitCCommon/GenericOpConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace {
using namespace mlir;
using namespace mlir::emitc;

/// Convert a common operation into an `emitc.call` operation.
/// Convert a common operation into an `emitc.call_opaque` operation.
template <typename SrcOp, typename Adaptor = typename SrcOp::Adaptor>
class GenericOpConversion : public OpConversionPattern<SrcOp> {
using OpConversionPattern<SrcOp>::OpConversionPattern;
Expand Down Expand Up @@ -57,9 +57,9 @@ class GenericOpConversion : public OpConversionPattern<SrcOp> {
templateArgs = ArrayAttr::get(srcOp.getContext(), templateArguments);
}

rewriter.replaceOpWithNewOp<emitc::CallOp>(srcOp, srcOp.getType(), callee,
args, templateArgs,
adaptor.getOperands());
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(srcOp, srcOp.getType(),
callee, args, templateArgs,
adaptor.getOperands());

return success();
}
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using namespace mlir::emitc;

namespace {

// Convert `arith.index_cast` into an `emitc.call` operation.
// Convert `arith.index_cast` into an `emitc.call_opaque` operation.
class IndexCastOpConversion : public OpConversionPattern<arith::IndexCastOp> {
using OpConversionPattern<arith::IndexCastOp>::OpConversionPattern;

Expand All @@ -43,7 +43,7 @@ class IndexCastOpConversion : public OpConversionPattern<arith::IndexCastOp> {
Type resultType = indexCastOp.getResult().getType();
ArrayAttr templateArgs = rewriter.getArrayAttr({TypeAttr::get(resultType)});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
indexCastOp, indexCastOp.getType(), callee, args, templateArgs,
adaptor.getOperands());

Expand Down
8 changes: 4 additions & 4 deletions lib/Conversion/StablehloToEmitC/StablehloRegionOpsToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ struct ConvertStablehloRegionOpsToEmitCPass

ArrayAttr templateArgs = ArrayAttr::get(ctx, templateArguments);

emitc::CallOp callOp = builder.create<emitc::CallOp>(
emitc::CallOpaqueOp callOpaqueOp = builder.create<emitc::CallOpaqueOp>(
op.getLoc(), op.getResultTypes(), callee, args, templateArgs, operands);
op.replaceAllUsesWith(callOp);
op.replaceAllUsesWith(callOpaqueOp);
op.erase();
return success();
}
Expand Down Expand Up @@ -217,9 +217,9 @@ struct ConvertStablehloRegionOpsToEmitCPass
ArrayAttr templateArgs =
ArrayAttr::get(ctx, {TypeAttr::get(op.getResult(0).getType())});

emitc::CallOp callOp = builder.create<emitc::CallOp>(
emitc::CallOpaqueOp callOpaqueOp = builder.create<emitc::CallOpaqueOp>(
op.getLoc(), op.getType(0), callee, args, templateArgs, operands);
op.replaceAllUsesWith(callOp);
op.replaceAllUsesWith(callOpaqueOp);
op.erase();
return success();
}
Expand Down
70 changes: 36 additions & 34 deletions lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class ConstOpConversion : public OpRewritePattern<stablehlo::ConstantOp> {
}
};

/// Convert `stablehlo.batch_norm_inference` into an `emitc.call` operation.
/// Convert `stablehlo.batch_norm_inference` into an `emitc.call_opaque`
/// operation.
class BatchNormInferenceOpConversion
: public OpConversionPattern<stablehlo::BatchNormInferenceOp> {

Expand All @@ -83,15 +84,15 @@ class BatchNormInferenceOpConversion
{TypeAttr::get(batchNormInferenceOp.getResult().getType()),
TypeAttr::get(adaptor.getScale().getType())});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
batchNormInferenceOp, batchNormInferenceOp.getType(), callee, args,
templateArgs, adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.broadcast_in_dim` into an `emitc.call` operation.
/// Convert `stablehlo.broadcast_in_dim` into an `emitc.call_opaque` operation.
class BroadcastInDimOpConversion
: public OpConversionPattern<stablehlo::BroadcastInDimOp> {

Expand All @@ -116,15 +117,15 @@ class BroadcastInDimOpConversion
ArrayAttr templateArgs = rewriter.getArrayAttr(
{TypeAttr::get(broadcastInDimOp.getResult().getType())});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
broadcastInDimOp, broadcastInDimOp.getType(), callee, args,
templateArgs, adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.concatenate` into an `emitc.call` operation.
/// Convert `stablehlo.concatenate` into an `emitc.call_opaque` operation.
class ConcatenateOpConversion
: public OpConversionPattern<stablehlo::ConcatenateOp> {

Expand All @@ -144,15 +145,15 @@ class ConcatenateOpConversion
{rewriter.getI64IntegerAttr(concatenateOp.getDimension()),
TypeAttr::get(concatenateOp.getResult().getType())});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
concatenateOp, concatenateOp.getType(), callee, args, templateArgs,
adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.convolution` into an `emitc.call` operation.
/// Convert `stablehlo.convolution` into an `emitc.call_opaque` operation.
class ConvOpConversion : public OpConversionPattern<stablehlo::ConvolutionOp> {

public:
Expand Down Expand Up @@ -206,15 +207,15 @@ class ConvOpConversion : public OpConversionPattern<stablehlo::ConvolutionOp> {
TypeAttr::get(adaptor.getLhs().getType()),
TypeAttr::get(adaptor.getRhs().getType())});

rewriter.replaceOpWithNewOp<emitc::CallOp>(convOp, convOp.getType(), callee,
args, templateArgs,
adaptor.getOperands());
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(convOp, convOp.getType(),
callee, args, templateArgs,
adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.compare` into an `emitc.call` operation.
/// Convert `stablehlo.compare` into an `emitc.call_opaque` operation.
class CompareOpConversion : public OpConversionPattern<stablehlo::CompareOp> {
using OpConversionPattern<stablehlo::CompareOp>::OpConversionPattern;

Expand Down Expand Up @@ -252,15 +253,15 @@ class CompareOpConversion : public OpConversionPattern<stablehlo::CompareOp> {
{TypeAttr::get(elementType),
emitc::OpaqueAttr::get(ctx, functionName.value())});

rewriter.replaceOpWithNewOp<emitc::CallOp>(compareOp, compareOp.getType(),
callee, args, templateArgs,
adaptor.getOperands());
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
compareOp, compareOp.getType(), callee, args, templateArgs,
adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.get_tuple_element` into an `emitc.call` operation.
/// Convert `stablehlo.get_tuple_element` into an `emitc.call_opaque` operation.
class GetTupleElementOpConversion
: public OpConversionPattern<stablehlo::GetTupleElementOp> {
using OpConversionPattern<stablehlo::GetTupleElementOp>::OpConversionPattern;
Expand All @@ -282,15 +283,15 @@ class GetTupleElementOpConversion
ArrayAttr templateArgs = rewriter.getArrayAttr(
{IntegerAttr::get(rewriter.getIntegerType(32), index)});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
getTupleElementOp, getTupleElementOp.getType(), callee, args,
templateArgs, adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.slice` into an `emitc.call` operation.
/// Convert `stablehlo.slice` into an `emitc.call_opaque` operation.
class SliceOpConversion : public OpConversionPattern<stablehlo::SliceOp> {
using OpConversionPattern<stablehlo::SliceOp>::OpConversionPattern;

Expand All @@ -316,15 +317,15 @@ class SliceOpConversion : public OpConversionPattern<stablehlo::SliceOp> {
ArrayAttr templateArgs =
rewriter.getArrayAttr({TypeAttr::get(sliceOp.getResult().getType())});

rewriter.replaceOpWithNewOp<emitc::CallOp>(sliceOp, sliceOp.getType(),
callee, args, templateArgs,
adaptor.getOperands());
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(sliceOp, sliceOp.getType(),
callee, args, templateArgs,
adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.dynamic_slice` into an `emitc.call` operation.
/// Convert `stablehlo.dynamic_slice` into an `emitc.call_opaque` operation.
class DynamicSliceOpConversion
: public OpConversionPattern<stablehlo::DynamicSliceOp> {
using OpConversionPattern<stablehlo::DynamicSliceOp>::OpConversionPattern;
Expand All @@ -350,15 +351,16 @@ class DynamicSliceOpConversion
ArrayAttr templateArgs = rewriter.getArrayAttr(
{TypeAttr::get(dynamicSliceOp.getResult().getType())});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
dynamicSliceOp, dynamicSliceOp.getType(), callee, args, templateArgs,
adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.dynamic_update_slice` into an `emitc.call` operation.
/// Convert `stablehlo.dynamic_update_slice` into an `emitc.call_opaque`
/// operation.
class DynamicUpdateSliceOpConversion
: public OpConversionPattern<stablehlo::DynamicUpdateSliceOp> {
using OpConversionPattern<
Expand All @@ -381,15 +383,15 @@ class DynamicUpdateSliceOpConversion
ArrayAttr templateArgs =
rewriter.getArrayAttr({TypeAttr::get(adaptor.getUpdate().getType())});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
dynamicUpdateSliceOp, dynamicUpdateSliceOp.getType(), callee, args,
templateArgs, adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.pad` into an `emitc.call` operation.
/// Convert `stablehlo.pad` into an `emitc.call_opaque` operation.
class PadOpConversion : public OpConversionPattern<stablehlo::PadOp> {
using OpConversionPattern<stablehlo::PadOp>::OpConversionPattern;

Expand All @@ -415,15 +417,15 @@ class PadOpConversion : public OpConversionPattern<stablehlo::PadOp> {
Type resultType = padOp.getResult().getType();
ArrayAttr templateArgs = rewriter.getArrayAttr({TypeAttr::get(resultType)});

rewriter.replaceOpWithNewOp<emitc::CallOp>(padOp, padOp.getType(), callee,
args, templateArgs,
adaptor.getOperands());
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(padOp, padOp.getType(),
callee, args, templateArgs,
adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.transpose` into an `emitc.call` operation.
/// Convert `stablehlo.transpose` into an `emitc.call_opaque` operation.
class TransposeOpConversion
: public OpConversionPattern<stablehlo::TransposeOp> {
using OpConversionPattern<stablehlo::TransposeOp>::OpConversionPattern;
Expand All @@ -447,15 +449,15 @@ class TransposeOpConversion
Type resultType = transposeOp.getResult().getType();
ArrayAttr templateArgs = rewriter.getArrayAttr({TypeAttr::get(resultType)});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
transposeOp, transposeOp.getType(), callee, args, templateArgs,
adaptor.getOperands());

return success();
}
};

/// Convert `stablehlo.rng` into an `emitc.call` operation.
/// Convert `stablehlo.rng` into an `emitc.call_opaque` operation.
class RngOpConversion : public OpConversionPattern<stablehlo::RngOp> {

public:
Expand All @@ -478,9 +480,9 @@ class RngOpConversion : public OpConversionPattern<stablehlo::RngOp> {
ArrayAttr templateArgs =
rewriter.getArrayAttr({TypeAttr::get(rngOp.getType())});

rewriter.replaceOpWithNewOp<emitc::CallOp>(rngOp, rngOp.getType(), callee,
args, templateArgs,
adaptor.getOperands());
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(rngOp, rngOp.getType(),
callee, args, templateArgs,
adaptor.getOperands());

return success();
}
Expand Down
12 changes: 6 additions & 6 deletions lib/Conversion/TensorToEmitC/TensorToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using namespace mlir::emitc;

namespace {

/// Convert `tensor.extract` into an `emitc.call` operation.
/// Convert `tensor.extract` into an `emitc.call_opaque` operation.
class ExtractOpConversion : public OpConversionPattern<tensor::ExtractOp> {
using OpConversionPattern<tensor::ExtractOp>::OpConversionPattern;

Expand All @@ -47,15 +47,15 @@ class ExtractOpConversion : public OpConversionPattern<tensor::ExtractOp> {
ArrayAttr args;
ArrayAttr templateArgs;

rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
indexCastOp, indexCastOp.getType(), callee, args, templateArgs,
adaptor.getOperands());

return success();
}
};

/// Convert `tensor.splat` into an `emitc.call` operation.
/// Convert `tensor.splat` into an `emitc.call_opaque` operation.
class SplatOpConversion : public OpConversionPattern<tensor::SplatOp> {
using OpConversionPattern<tensor::SplatOp>::OpConversionPattern;

Expand All @@ -73,9 +73,9 @@ class SplatOpConversion : public OpConversionPattern<tensor::SplatOp> {
Type resultType = splatOp.getResult().getType();
ArrayAttr templateArgs = rewriter.getArrayAttr({TypeAttr::get(resultType)});

rewriter.replaceOpWithNewOp<emitc::CallOp>(splatOp, splatOp.getType(),
callee, args, templateArgs,
adaptor.getOperands());
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(splatOp, splatOp.getType(),
callee, args, templateArgs,
adaptor.getOperands());

return success();
}
Expand Down
Loading

0 comments on commit 0fa488a

Please sign in to comment.