[AutoBump] Merge with f08bfc4 (Oct 04)
mgehre-amd committed Dec 17, 2024
2 parents adb3f09 + f08bfc4, commit bf1ee4a
Showing 3 changed files with 37 additions and 19 deletions.
34 changes: 28 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -2521,7 +2521,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return failure();
 
         auto shapeSizes = shapeType.getSizes();
-        int64_t dataRank = dataType.getSizes().size();
+        ArrayRef<int64_t> dataShape = dataType.getSizes();
+        int64_t dataRank = dataShape.size();
         int64_t shapeRank = shapeSizes.size();
         if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
           return failure();
@@ -2543,22 +2544,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         // we are using torch implementation Torch::AtenBroadcastToOp which
         // takes list of int
         for (int i = 0; i < shapeSizes[0]; i++) {
-          // extract dim from shape
           Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
               loc, rewriter.getType<Torch::IntType>(),
               rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
           Value extract = rewriter.create<Torch::AtenSelectIntOp>(
               loc, selectResultType, shape, zero, selectIndex);
-          Value dim = rewriter.create<Torch::AtenItemOp>(
+          Value selectDim = rewriter.create<Torch::AtenItemOp>(
               loc, rewriter.getType<Torch::IntType>(), extract);
-          if (i + rankDifference >= 0) {
+          // compute dim to pass to broadcast op. For non-broadcastable dims,
+          // pass -1
+          Value dim;
+          if (i + rankDifference >= 0 && dataShape[i + rankDifference] != 1) {
+            // 1. if dataShape[i + rankDiff] > 1, then this cannot be
+            // broadcasted
+            // 2. we will explicitly disallow broadcasting dynamic dims that are
+            // secretly 1.
+            dim = rewriter.create<Torch::ConstantIntOp>(loc, -1);
+            // Assert dataShape[i + rankDiff] >= selectDim. If both are
+            // constant, this should fold out.
             Value iv =
                 rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
             auto sz = rewriter.create<Torch::AtenSizeIntOp>(
                 loc, rewriter.getType<Torch::IntType>(), data, iv);
-            dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz);
+            Value gtSelect =
+                rewriter.create<Torch::AtenGeIntOp>(loc, sz, selectDim);
+            rewriter.create<Torch::RuntimeAssertOp>(
+                loc, gtSelect,
+                rewriter.getStringAttr(
+                    "onnx.Expand input has a dim that is not statically 1; "
+                    "expected this dim >= dim provided shape."));
+          } else {
+            // 1. excess selectDims get included in broadcast (shapeSizes[0] >
+            // dataRank)
+            // 2. selectDims which correspond to dataShape == 1 get included in
+            // broadcast
+            dim = selectDim;
           }
 
           dimList.push_back(dim);
         }
         Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
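
Aside: the new dim computation is easier to read outside the rewriter plumbing. The following is a minimal Python sketch of the rule this hunk implements; the helper name is illustrative, and a plain Python assert stands in for the Torch::RuntimeAssertOp that the lowering actually emits into the IR.

def expand_dims(data_shape, target_shape):
    """Mirror of the dim list the lowering feeds to AtenBroadcastToOp."""
    rank_diff = len(data_shape) - len(target_shape)
    dims = []
    for i, want in enumerate(target_shape):
        j = i + rank_diff
        if j >= 0 and data_shape[j] != 1:
            # Non-broadcastable dim: keep the input size by passing -1, and
            # require data_shape[j] >= requested size (the runtime assert).
            assert data_shape[j] >= want, "onnx.Expand: dim is not statically 1"
            dims.append(-1)
        else:
            # New leading dims and statically-1 dims take the requested size.
            dims.append(want)
    return dims

assert expand_dims([1, 4], [3, 4]) == [3, -1]
assert expand_dims([3, 1], [2, 3, 6]) == [2, -1, 6]

The two trailing asserts correspond to the two FileCheck tests updated below.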
@@ -42,7 +42,7 @@ def import_onnx(contents):
     # Import the ONNX model proto from the file contents:
     raw_model = onnx.load_from_string(contents)
     # since it does not affect current e2e tests, data_prop is left false here
-    model_proto = onnx.shape_inference.infer_shapes(raw_model)
+    model_proto = onnx.shape_inference.infer_shapes(raw_model, data_prop=True)
 
     # Import the ONNX module into an MLIR module:
     context = Context()
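
For context on data_prop: plain infer_shapes only propagates ranks, dtypes, and shapes, while data_prop=True additionally propagates constant values through ops such as onnx.Shape, which is what can make the shape operand of onnx.Expand statically known. A self-contained sketch under that assumption (graph and tensor names are illustrative):

import onnx
from onnx import TensorProto, helper

# The `target_shape` operand of Expand is computed by onnx.Shape, not a
# constant initializer; without data_prop=True its value is opaque to
# shape inference and the Expand output shape stays unknown.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
ref = helper.make_tensor_value_info("ref", TensorProto.FLOAT, [3, 4])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, None)
nodes = [
    helper.make_node("Shape", ["ref"], ["target_shape"]),
    helper.make_node("Expand", ["x", "target_shape"], ["y"]),
]
graph = helper.make_graph(nodes, "expand_demo", [x, ref], [y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])

inferred = onnx.shape_inference.infer_shapes(model, data_prop=True)
print(inferred.graph.output[0])  # should now carry static dims 3 and 4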
20 changes: 8 additions & 12 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -1608,16 +1608,13 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>)
 // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
 // CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
 // CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
-// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
-// CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
-// CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
 // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
 // CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
 // CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
-// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
-// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
-// CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK-DAG: %[[Im1:.+]] = torch.constant.int -1
+// CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
+// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1_1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
+// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
   %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
   return %0 : !torch.vtensor<[3,4],f32>
@@ -1634,16 +1631,15 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>)
 // CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
 // CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
 // CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
+// CHECK-NEXT: %[[Im1:.+]] = torch.constant.int -1
 // CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
 // CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
-// CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
+// CHECK-NEXT: %[[GE:.+]] = torch.aten.ge.int
+// CHECK-NEXT: torch.runtime.assert %[[GE]]
 // CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
 // CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
 // CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
-// CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1
-// CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
-// CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
-// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
+// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]], %[[ITEM2]]
 // CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
 // CHECK: return %[[EXPAND]]
   %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
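
The two updated tests correspond to the following PyTorch-level behavior, where -1 in the broadcast list means "keep the input size". This is a hedged analogue using Tensor.expand rather than the generated torch.aten.broadcast_to IR:

import torch

# test_expand_dim2_shape2: [1,4] expanded to [3,4]; dim 1 (size 4) is
# kept via -1, dim 0 (statically 1) is broadcast to 3.
a = torch.rand(1, 4)
assert a.expand(3, -1).shape == (3, 4)

# test_expand_dim2_shape3: [3,1] expanded to [2,3,6]; input dim 0
# (size 3, not statically 1) is kept via -1, and the emitted runtime
# assert checks 3 >= the requested size at that position.
b = torch.rand(3, 1)
assert b.expand(2, -1, 6).shape == (2, 3, 6)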
