Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoBump] Merge with fixes of 297c2709 (May 20) (43) #276

Merged
merged 49 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
297c270
onnx.Resize and aten._interpolate : allow n spatial dims. (#3368)
zjgarvey May 20, 2024
c0e7d26
[torch-mlir][sparse] inference mode for sparse GCN test (#3369)
aartbik May 21, 2024
b870729
[torch] Fix `onnx.MaxPool` lowering (#3133)
vivekkhandelwal1 May 21, 2024
c2c1c2c
[FxImporter] Fix failed e2e case (#3365)
penguin-wwy May 21, 2024
fcf4887
[ONNX] Implement Softsign op (#3373)
RattataKing May 21, 2024
560ca24
[torch-mlir][sparse] replace xavier with ones initialization (#3374)
aartbik May 22, 2024
6e48557
[Pipeline] Use dedicated simplification pipeline for TorchDynamo fron…
sjain-stanford May 22, 2024
52be4bd
[ONNX] Fix bugs for the `onnx.OneHot` operator (#3334)
angelz913 May 22, 2024
972d47b
[FxImporter] Fix constant bool tensor (#3375)
penguin-wwy May 22, 2024
4d7cdba
[Torch] eliminate "getWithLeastStaticInformation" in DecomposeAtenTri…
May 22, 2024
f4bfe3f
Bump llvm and stablehlo (#3377)
qingyunqu May 22, 2024
2e194e1
[Torch] Fix bugs for `Torch::AtenOneHotOp` (#3350)
angelz913 May 22, 2024
d924d00
[FxImporter] Fix primitive type in return (#3379)
penguin-wwy May 23, 2024
43f961e
[MLIR] Fix 64-bit product during aten.view lowering (#3378)
Shukla-Gaurav May 23, 2024
5bb1a65
[Stablehlo] refactor reduction lowering and support aten.amin (#3383)
qingyunqu May 23, 2024
27169dc
Replace some depreciated uses of cast (#3343)
zjgarvey May 23, 2024
28aeb04
[Stablehlo] fix crashing on AtenEmbeddingBagSumExample_basic (#3389)
qingyunqu May 26, 2024
05929f9
enhance verbose option in e2e_testing (#3390)
qingyunqu May 27, 2024
e0a5adb
[Torch] fix aten.linear's decomposition (#3391)
qingyunqu May 27, 2024
a5d3b54
[FxImporter] Fix embedding bag (#3387)
penguin-wwy May 29, 2024
23d2d66
Fix error when attempting to read elided onnx constants (#3398)
renxida May 29, 2024
1f544c3
[NFC] Remove unused header files (#3386)
penguin-wwy May 30, 2024
e4be197
[FxImporter] Fix transpose rank zero (#3382)
penguin-wwy May 30, 2024
d7b8f00
[ONNX] Add OnnxToTorch Lowering for LpNormalization op (#3397)
vivekkhandelwal1 May 30, 2024
074098d
Modifies onnx resize lowering to fix numerical issues (#3381)
zjgarvey May 31, 2024
4e05e2c
[Torch] support recompose of aten.split.with_sizes and aten.tensor_sp…
qingyunqu May 31, 2024
cb36327
[AutoBump] Merge with fixes of 297c2709 (May 20)
mgehre-amd Aug 27, 2024
8e802ea
[AutoBump] Merge with 4d7cdba4 (May 22)
mgehre-amd Aug 27, 2024
74a6d69
[AutoBump] Merge with fixes of f4bfe3f9 (May 22)
mgehre-amd Aug 27, 2024
15311b9
[AutoBump] Merge with 05929f91 (May 27)
mgehre-amd Aug 27, 2024
27d14b3
[AutoBump] Merge with fixes of e0a5adb1 (May 27)
mgehre-amd Aug 27, 2024
2d73754
[AutoBump] Merge with d7b8f00d (May 30)
mgehre-amd Aug 27, 2024
04b641b
[AutoBump] Merge with fixes of 074098d2 (May 31)
mgehre-amd Aug 27, 2024
b288137
[AutoBump] Merge with fixes of 4e05e2cd (May 31)
mgehre-amd Aug 27, 2024
db67a0e
Merge branch 'bump_to_99511cef' into bump_to_297c2709
mgehre-amd Aug 28, 2024
8946c23
Merge branch 'bump_to_297c2709' into bump_to_4d7cdba4
mgehre-amd Aug 28, 2024
b582114
Merge branch 'bump_to_4d7cdba4' into bump_to_f4bfe3f9
mgehre-amd Aug 28, 2024
176f877
Merge branch 'bump_to_f4bfe3f9' into bump_to_05929f91
mgehre-amd Aug 28, 2024
483e32b
Merge branch 'bump_to_05929f91' into bump_to_e0a5adb1
mgehre-amd Aug 28, 2024
ace855e
Merge branch 'bump_to_e0a5adb1' into bump_to_d7b8f00d
mgehre-amd Aug 28, 2024
16c2dda
Merge branch 'bump_to_d7b8f00d' into bump_to_074098d2
mgehre-amd Aug 28, 2024
7489bec
Merge branch 'bump_to_074098d2' into bump_to_4e05e2cd
mgehre-amd Aug 29, 2024
226c69a
Merge pull request #277 from Xilinx/bump_to_4d7cdba4
mgehre-amd Sep 9, 2024
c2fdddd
Merge pull request #278 from Xilinx/bump_to_f4bfe3f9
mgehre-amd Sep 9, 2024
5d648ed
Merge pull request #279 from Xilinx/bump_to_05929f91
mgehre-amd Sep 9, 2024
5b76ad7
Merge pull request #280 from Xilinx/bump_to_e0a5adb1
mgehre-amd Sep 9, 2024
8cb883b
Merge pull request #281 from Xilinx/bump_to_d7b8f00d
mgehre-amd Sep 9, 2024
85b4eec
Merge pull request #282 from Xilinx/bump_to_074098d2
mgehre-amd Sep 9, 2024
3d8b237
Merge pull request #283 from Xilinx/bump_to_4e05e2cd
mgehre-amd Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ externals/pytorch/
libtorch*

/build/
.build-cache/
/setup_build/
__pycache__
*.pyc
Expand Down
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 148 files
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ namespace hlo {

using mlir::ConversionPatternRewriter;

// Create chlo::ConstantLikeOp
template <typename T>
Value getConstantLike(OpBuilder &rewriter, Location loc, T constant, Value val);

// Create a 32-bit float constant operator from a float
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);
Expand Down
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13577,6 +13577,31 @@ def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasCanonicalizer = 1;
}

// ODS definition for `aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])`.
// Splits `self` into `sections` chunks along dimension `dim`; unlike
// aten.split, the sections need not divide the dimension size evenly.
def Torch_AtenTensorSplitSectionsOp : Torch_Op<"aten.tensor_split.sections", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$sections,
Torch_IntType:$dim
);
let results = (outs
AnyTorchListOfTensorType:$result
);
// Parsing/printing is delegated to the shared default-Torch-op helpers,
// which handle the fixed 3-operand / 1-result signature.
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenTensorSplitSectionsOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenTensorSplitSectionsOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
Expand Down
5 changes: 5 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ struct TorchLoweringPipelineOptions
void createTorchScriptModuleToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);

/// Creates a pipeline that lowers the graph IR that is produced by
/// TorchDynamo export into the form expected by torch-verify-backend-contract.
void createTorchDynamoExportToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);

/// Creates a pipeline that lowers a flat list of funcs and global slots
/// with the torch and aten dialects and mutable arrays and converts it to
/// the form required by torch-verify-backend-contract.
Expand Down
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ int64_t getNumberOfElements(RankedTensorType inputType);
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);

ValueTensorType getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
Type dtype);
Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim);

// Helper function to squeeze the input tensor at given dim.
// Return the squeezed tensor or failure.
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
Expand Down
40 changes: 20 additions & 20 deletions lib/CAPI/TorchTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
}

/// Returns the type contained in a !torch.optional type.
MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
  // Free-function cast<> form; the member Type::cast<>() is deprecated.
  auto type = cast<Torch::OptionalType>(unwrap(t));
  return wrap(type.getContainedType());
}

Expand All @@ -77,12 +77,12 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
}

/// Returns the number of types contained in a !torch.tuple type.
size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) {
  auto type = cast<Torch::TupleType>(unwrap(t));
  return type.getContainedTypes().size();
}

/// Returns the pos-th contained type of a !torch.tuple type.
/// `pos` must be a valid index; no bounds check is performed here.
MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
  auto type = cast<Torch::TupleType>(unwrap(t));
  return wrap(type.getContainedTypes()[pos]);
}

Expand All @@ -108,12 +108,12 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
}

/// Returns the number of types contained in a !torch.union type.
size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) {
  auto type = cast<Torch::UnionType>(unwrap(t));
  return type.getContainedTypes().size();
}

/// Returns the pos-th contained type of a !torch.union type.
/// `pos` must be a valid index; no bounds check is performed here.
MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
  auto type = cast<Torch::UnionType>(unwrap(t));
  return wrap(type.getContainedTypes()[pos]);
}

Expand All @@ -134,7 +134,7 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) {
}

/// Returns the element type of a !torch.list type.
MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
  return wrap(cast<Torch::ListType>(unwrap(t)).getContainedType());
}

MlirTypeID torchMlirTorchListTypeGetTypeID() {
Expand Down Expand Up @@ -297,26 +297,26 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(

/// Builds a !torch.tensor (non-value) type whose shape and element type
/// mirror those of the given typed attribute's RankedTensorType.
/// The attribute must be a TypedAttr holding a ranked tensor type.
MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
  auto attrTensorType =
      cast<RankedTensorType>(cast<TypedAttr>(unwrap(attr)).getType());
  return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
                                             attrTensorType.getShape(),
                                             attrTensorType.getElementType()));
}

/// Returns the rank of a !torch.tensor type.
/// NOTE(review): assumes the type has known sizes — callers should check
/// torchMlirTorchNonValueTensorTypeHasSizes first.
int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) {
  return cast<Torch::NonValueTensorType>(unwrap(t)).getSizes().size();
}

/// Returns whether a !torch.tensor type has a known size list.
bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) {
  return cast<Torch::NonValueTensorType>(unwrap(t)).hasSizes();
}

/// Returns whether a !torch.tensor type has a known dtype.
bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) {
  return cast<Torch::NonValueTensorType>(unwrap(t)).hasDtype();
}

int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
auto tensorType = unwrap(t).cast<Torch::NonValueTensorType>();
auto tensorType = cast<Torch::NonValueTensorType>(unwrap(t));
bool hasSizes = tensorType.hasSizes();
if (!hasSizes)
return -1;
Expand All @@ -329,7 +329,7 @@ int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
}

/// Returns the dtype of a !torch.tensor type.
/// NOTE(review): assumes the type has a dtype — callers should check
/// torchMlirTorchNonValueTensorTypeHasDtype first.
MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
  return wrap(cast<Torch::NonValueTensorType>(unwrap(t)).getDtype());
}

MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
Expand Down Expand Up @@ -364,26 +364,26 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(

/// Builds a !torch.vtensor (value) type whose shape and element type
/// mirror those of the given typed attribute's RankedTensorType.
/// The attribute must be a TypedAttr holding a ranked tensor type.
MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
  auto attrTensorType =
      cast<RankedTensorType>(cast<TypedAttr>(unwrap(attr)).getType());
  return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
                                          attrTensorType.getShape(),
                                          attrTensorType.getElementType()));
}

/// Returns the rank of a !torch.vtensor type.
/// NOTE(review): assumes the type has known sizes — callers should check
/// torchMlirTorchValueTensorTypeHasSizes first.
int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) {
  return cast<Torch::ValueTensorType>(unwrap(t)).getSizes().size();
}

/// Returns whether a !torch.vtensor type has a known size list.
bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) {
  return cast<Torch::ValueTensorType>(unwrap(t)).hasSizes();
}

/// Returns whether a !torch.vtensor type has a known dtype.
bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) {
  return cast<Torch::ValueTensorType>(unwrap(t)).hasDtype();
}

int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
auto tensorType = unwrap(t).cast<Torch::ValueTensorType>();
auto tensorType = cast<Torch::ValueTensorType>(unwrap(t));
bool hasSizes = tensorType.hasSizes();
if (!hasSizes)
return -1;
Expand All @@ -396,7 +396,7 @@ int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
}

/// Returns the dtype of a !torch.vtensor type.
/// NOTE(review): assumes the type has a dtype — callers should check
/// torchMlirTorchValueTensorTypeHasDtype first.
MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
  return wrap(cast<Torch::ValueTensorType>(unwrap(t)).getDtype());
}

MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
Expand Down Expand Up @@ -487,12 +487,12 @@ MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType,
}

/// Returns the key type of a !torch.dict type.
MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) {
  auto type = cast<Torch::DictType>(unwrap(t));
  return wrap(type.getKeyType());
}

/// Returns the value type of a !torch.dict type.
MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
  auto type = cast<Torch::DictType>(unwrap(t));
  return wrap(type.getValueType());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

Expand Down
29 changes: 18 additions & 11 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ LogicalResult windowFunctionImpl(OpBinder binder,
// Create an f32 ValueTensorType with the same size as `size`, the
// operand
auto shapeOfOperand =
size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
dyn_cast<Torch::ValueTensorType>(size.getType()).getOptionalSizes();
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
shapeOfOperand, rewriter.getF32Type());
Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
Expand Down Expand Up @@ -897,8 +897,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
}

if (DenseResourceElementsAttr attr =
binder.op->getAttr("torch.onnx.value")
.dyn_cast_or_null<DenseResourceElementsAttr>()) {
dyn_cast_or_null<DenseResourceElementsAttr>(
binder.op->getAttr("torch.onnx.value"))) {
// Bytes are stored in little endian order. Big endian support will
// require swizzling.
if (!Endian::little) {
Expand All @@ -909,25 +909,34 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

auto ty = cast<ShapedType>(attr.getType());
ElementsAttr denseAttr;
auto ptr = attr.getRawHandle().getBlob()->getData();
auto ptr = attr.getRawHandle().getBlob();
if (!ptr) {
denseAttr = DenseResourceElementsAttr::get(
ty, "__onnx_constant_not_found_possibly_due_to_being_elided__",
AsmResourceBlob());
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, denseAttr);
return success();
}
auto data = ptr->getData();
if (cast<ShapedType>(attr.getType()).getElementType().isInteger(1)) {
llvm::SmallVector<APInt> newContents;
for (auto val : ptr) {
for (auto val : data) {
APInt apval(1, val);
newContents.push_back(apval);
}
denseAttr = DenseElementsAttr::get(ty, newContents);
} else {
denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
denseAttr = DenseElementsAttr::getFromRawBuffer(ty, data);
}

rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, denseAttr);
return success();
}

if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value")
.dyn_cast_or_null<ElementsAttr>()) {
if (ElementsAttr attr = dyn_cast_or_null<ElementsAttr>(
binder.op->getAttr("torch.onnx.value"))) {
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, attr);
return success();
Expand Down Expand Up @@ -2283,9 +2292,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType))
return failure();
Type listElemType =
tensors[0]
.getType()
.cast<Torch::BaseTensorType>()
cast<Torch::BaseTensorType>(tensors[0].getType())
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Expand Down
Loading
Loading