From 8900f67c08acb808576716e2d375ae7abc7e33ba Mon Sep 17 00:00:00 2001
From: apbose
Date: Tue, 10 Sep 2024 21:24:51 -0700
Subject: [PATCH] scatter reduce lowering with include_self=False

---
 .../dynamo/lowering/_decompositions.py        |  94 +++-
 .../py/dynamo/lowering/test_decompositions.py | 479 ++++++++++++++++--
 2 files changed, 529 insertions(+), 44 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 4ffbcfdedb..2f90fdf2c0 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -303,21 +303,30 @@ def __new__(cls, description: Any, func: Any) -> Any:
         obj.func = func
         return obj

-    def reduce_operation_with_scatter(
+    def reduce_operation_with_scatter_include_self(
         self,
         operation_lhs: Any,
         initial_tensor: torch.Tensor,
         dim: int,
         index_tensor: torch.Tensor,
         src_tensor: torch.Tensor,
+        min_ele: float = float("-inf"),
+        max_ele: float = float("inf"),
+        include_self: bool = True,
     ) -> Any:
         scatter_tensor = None
         if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
             scatter_tensor = torch.zeros_like(initial_tensor)
         elif self == ReduceOperation.PROD:
             scatter_tensor = torch.ones_like(initial_tensor)
-        elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
+        elif self == ReduceOperation.AMAX:
             scatter_tensor = initial_tensor
+            if not include_self:
+                # without self, start from the identity of max
+                scatter_tensor = torch.full_like(initial_tensor, min_ele)
+        elif self == ReduceOperation.AMIN:
+            scatter_tensor = initial_tensor
+            if not include_self:
+                # without self, start from the identity of min
+                scatter_tensor = torch.full_like(initial_tensor, max_ele)
         else:
             # This case would not be encountered from torch itself
             print("Invalid Operation for Reduce op!!")
@@ -341,13 +350,39 @@ def scatter_reduce_decomposition(
     include_self: bool = True,
 ) -> torch.Tensor:
     scatter_loop_tensor = input_tensor
+    # dtype-dependent sentinels used to neutralize the indexed destination slots
+    if src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32:
+        MAX_ELE = 2147483647
+        MIN_ELE = -2147483648
+    else:
+        MAX_ELE = float("inf")
+        MIN_ELE = float("-inf")
+    if not include_self:
+        # overwrite the indexed destination slots with the reduction's
+        # identity element before the slice-by-slice reduction loop
+        if reduce == "sum" or reduce == "mean":
+            scatter_loop_tensor = torch.scatter(
+                scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor)
+            )
+        if reduce == "prod":
+            scatter_loop_tensor = torch.scatter(
+                scatter_loop_tensor, dim, index, torch.ones_like(src_tensor)
+            )
+        if reduce == "amax":
+            src_red_tensor = torch.full_like(src_tensor, MIN_ELE)
+            scatter_loop_tensor = torch.scatter(
+                scatter_loop_tensor, dim, index, src_red_tensor
+            )
+        if reduce == "amin":
+            src_red_tensor = torch.full_like(src_tensor, MAX_ELE)
+            scatter_loop_tensor = torch.scatter(
+                scatter_loop_tensor, dim, index, src_red_tensor
+            )
     device_input_tensor = input_tensor.device
     # required for mean reduce operation
     scatter_count_tensor = torch.zeros_like(input_tensor)
     src_shape = list(src_tensor.shape)
     src_dim = src_shape[dim]
-    if not include_self:
-        raise AssertionError("include_self False for scatter reduce not yet supported")
+
     for i in range(0, src_dim):
         src_slice = torch.select(src_tensor, dim, i)
         index_slice = torch.select(index, dim, i)
@@ -371,20 +406,53 @@ def scatter_reduce_decomposition(
                     dim,
                     index_slice,
                     torch.ones_like(src_slice),
+                    MIN_ELE,
+                    MAX_ELE,
+                    include_self,
                 )
         elif reduce == "amax":
             reduceOp = ReduceOperation.AMAX
         elif reduce == "amin":
             reduceOp = ReduceOperation.AMIN
-        scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
-            scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
+        scatter_loop_tensor = reduceOp.reduce_operation_with_scatter_include_self(
+            scatter_loop_tensor,
+            input_tensor,
+            dim,
+            index_slice,
+            src_slice,
+            MIN_ELE,
+            MAX_ELE,
+            include_self,
         )
     if reduce == "mean":
         scatter_loop_tensor = torch.div(
             scatter_loop_tensor,
-            torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
+            (
+                torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor))
+                if include_self
+                else scatter_count_tensor
+            ),
             rounding_mode="trunc",
         )
+    # with include_self=False, amax and amin need one final pass: take the
+    # elementwise max (resp. min) of the loop result against the input tensor
+    # with its indexed slots masked to MIN_ELE (resp. MAX_ELE), so that every
+    # position not referenced by `index` ends up with its original input value
+    if reduce == "amax" and not include_self:
+        # indexed slots keep the reduced max, all other slots the input value
+        return torch.max(
+            scatter_loop_tensor,
+            torch.scatter(
+                input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)
+            ),
+        )
+    if reduce == "amin" and not include_self:
+        # indexed slots keep the reduced min, all other slots the input value
+        return torch.min(
+            scatter_loop_tensor,
+            torch.scatter(
+                input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)
+            ),
+        )
     return scatter_loop_tensor
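Note: the eager-mode semantics the decomposition above must reproduce can be
sanity-checked against the public torch.Tensor.scatter_reduce API. A minimal
sketch (CPU tensors for illustration; the data mirrors the
scatter_reduce_add_one_include_self_False test case added below):

    import torch

    # include_self=False: values already in `input` at indexed slots are
    # ignored; only `src` values are reduced into those slots, and slots
    # never referenced by `index` keep their original input values.
    input_t = torch.tensor([1, 2, 3, 4, 6, 7], dtype=torch.int32)
    index = torch.tensor([0, 1, 0, 1])
    src = torch.tensor([1, 2, 3, 4], dtype=torch.int32)

    out = input_t.scatter_reduce(0, index, src, reduce="sum", include_self=False)
    print(out)  # tensor([4, 6, 3, 4, 6, 7], dtype=torch.int32): 1+3, 2+4, rest unchanged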
"scatter_reduce_prod_zero_dim_indexOne_constant", + "scatter_reduce_add_zero_dim_indexOne_constant_include_self_False", + 0, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), + {torch.ops.aten.add.Tensor}, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "sum", + False, + ), + ( + "scatter_reduce_add_zero_dim_indexTwo_constant_include_self_False", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), + {torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src}, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "sum", + False, + ), + ( + "scatter_reduce_add_one_dim_indexOne_constant_include_self_False", + 1, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "sum", + False, + ), + ( + "scatter_reduce_add_one_dim_indexTwo_constant_include_self_False", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "sum", + False, + ), + ( + "scatter_reduce_add_one_dim_indexOne_constant_3D_include_self_False", + 1, + torch.tensor( + [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] + ).cuda(), + torch.tensor( + [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]], + dtype=torch.int32, + ).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), + "sum", + False, + ), + ( + "scatter_reduce_add_one_include_self_False", + 0, + torch.tensor([0, 1, 0, 1]).cuda(), + torch.tensor([1, 2, 3, 4], dtype=torch.int32).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.scatter.src, + }, + torch.tensor([1, 2, 3, 4, 6, 7], dtype=torch.int32).cuda(), + "sum", + False, + ), + ############################prod########################### + ( + "scatter_reduce_prod_zero_dim_indexOne_constant_include_self_True", 0, torch.tensor([[0, 1, 2, 0]]).cuda(), torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), @@ -1203,9 +1286,10 @@ def forward(self, input): }, torch.ones(3, 5, dtype=torch.int32).cuda(), "prod", + True, ), ( - "scatter_reduce_prod_zero_dim_indexTwo_constant", + "scatter_reduce_prod_zero_dim_indexTwo_constant_include_self_True", 0, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), @@ -1215,9 +1299,10 @@ def forward(self, input): }, torch.ones(3, 5, dtype=torch.int32).cuda(), "prod", + True, ), ( - "scatter_reduce_prod_one_dim_indexOne_constant", + "scatter_reduce_prod_one_dim_indexOne_constant_include_self_True", 1, torch.tensor([[0, 1, 2, 0]]).cuda(), torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), @@ -1227,9 +1312,10 @@ def forward(self, input): }, torch.ones(3, 5, dtype=torch.int32).cuda(), "prod", + True, ), ( - "scatter_reduce_prod_one_dim_indexTwo_constant", + "scatter_reduce_prod_one_dim_indexTwo_constant_include_self_True", 1, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), @@ -1239,9 +1325,10 @@ def forward(self, input): }, torch.ones(3, 5, dtype=torch.int32).cuda(), "prod", + True, ), ( - "scatter_reduce_prod_one_dim_indexTwo_constant_3D", + 
"scatter_reduce_prod_one_dim_indexTwo_constant_3D_include_self_True", 1, torch.tensor( [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] @@ -1256,10 +1343,170 @@ def forward(self, input): }, torch.ones(3, 5, 6, dtype=torch.int32).cuda(), "prod", + True, + ), + ########################## prod include_self= False############################# + ( + "scatter_reduce_prod_zero_dim_indexOne_constant_include_self_False", + 0, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), + { + torch.ops.aten.mul.Tensor, + torch.ops.aten.scatter.src, + }, + torch.ones(3, 5, dtype=torch.int32).cuda(), + "prod", + False, + ), + ( + "scatter_reduce_prod_zero_dim_indexTwo_constant_include_self_False", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), + { + torch.ops.aten.mul.Tensor, + torch.ops.aten.scatter.src, + }, + torch.ones(3, 5, dtype=torch.int32).cuda(), + "prod", + False, + ), + ( + "scatter_reduce_prod_one_dim_indexOne_constant_include_self_False", + 1, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), + { + torch.ops.aten.mul.Tensor, + torch.ops.aten.scatter.src, + }, + torch.ones(3, 5, dtype=torch.int32).cuda(), + "prod", + False, + ), + ( + "scatter_reduce_prod_one_dim_indexTwo_constant_include_self_False", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), + { + torch.ops.aten.mul.Tensor, + torch.ops.aten.scatter.src, + }, + torch.ones(3, 5, dtype=torch.int32).cuda(), + "prod", + False, + ), + ( + "scatter_reduce_prod_one_dim_indexTwo_constant_3D_include_self_False", + 1, + torch.tensor( + [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] + ).cuda(), + torch.tensor( + [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]], + dtype=torch.int32, + ).cuda(), + { + torch.ops.aten.mul.Tensor, + torch.ops.aten.scatter.src, + }, + torch.ones(3, 5, 6, dtype=torch.int32).cuda(), + "prod", + False, + ), + ( + "scatter_reduce_prod_one_include_self_False", + 0, + torch.tensor([0, 1, 0, 1]).cuda(), + torch.tensor([1, 2, 3, 4], dtype=torch.int32).cuda(), + { + torch.ops.aten.mul.Tensor, + torch.ops.aten.scatter.src, + }, + torch.tensor([1, 2, 3, 4, 6, 7], dtype=torch.int32).cuda(), + "prod", + False, + ), + ##########################mean########################### + ( + "scatter_reduce_mean_zero_dim_indexOne_constant_include_self_True", + 0, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.div.Tensor_mode, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "mean", + True, + ), + ( + "scatter_reduce_mean_zero_dim_indexTwo_constant_include_self_True", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "mean", + True, + ), + ( + "scatter_reduce_mean_one_dim_indexOne_constant_include_self_True", + 1, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "mean", + True, + ), + ( + 
"scatter_reduce_mean_one_dim_indexTwo_constant_include_self_True", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "mean", + True, ), - # #############################mean########################### ( - "scatter_reduce_mean_zero_dim_indexOne_constant", + "scatter_reduce_mean_one_dim_indexTwo_constant_3D_include_self_True", + 1, + torch.tensor( + [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] + ).cuda(), + torch.tensor( + [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]], + dtype=torch.int32, + ).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), + "mean", + True, + ), + ########################## mean include_self= False############################# + ( + "scatter_reduce_mean_zero_dim_indexOne_constant_include_self_False", 0, torch.tensor([[0, 1, 2, 0]]).cuda(), torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), @@ -1269,9 +1516,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "mean", + False, ), ( - "scatter_reduce_mean_zero_dim_indexTwo_constant", + "scatter_reduce_mean_zero_dim_indexTwo_constant_include_self_False", 0, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), @@ -1282,9 +1530,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "mean", + False, ), ( - "scatter_reduce_mean_one_dim_indexOne_constant", + "scatter_reduce_mean_one_dim_indexOne_constant_include_self_False", 1, torch.tensor([[0, 1, 2, 0]]).cuda(), torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), @@ -1295,9 +1544,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "mean", + False, ), ( - "scatter_reduce_mean_one_dim_indexTwo_constant", + "scatter_reduce_mean_one_dim_indexTwo_constant_include_self_False", 1, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), @@ -1308,9 +1558,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "mean", + False, ), ( - "scatter_reduce_mean_one_dim_indexTwo_constant_3D", + "scatter_reduce_mean_one_dim_indexTwo_constant_3D_include_self_False", 1, torch.tensor( [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] @@ -1326,10 +1577,11 @@ def forward(self, input): }, torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), "mean", + False, ), - # #############################amax########################### + #############################amax########################### ( - "scatter_reduce_amax_zero_dim_indexOne_constant", + "scatter_reduce_amax_zero_dim_indexOne_constant_include_self_True", 0, torch.tensor([[0, 1, 2, 0]]).cuda(), torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), @@ -1339,9 +1591,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "amax", + True, ), ( - "scatter_reduce_amax_zero_dim_indexTwo_constant", + "scatter_reduce_amax_zero_dim_indexTwo_constant_include_self_True", 0, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), @@ -1351,9 +1604,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "amax", + True, ), ( - 
"scatter_reduce_amax_one_dim_indexOne_constant", + "scatter_reduce_amax_one_dim_indexOne_constant_include_self_True", 1, torch.tensor([[0, 1, 2, 0]]).cuda(), torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), @@ -1363,9 +1617,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "amax", + True, ), ( - "scatter_reduce_amax_one_dim_indexTwo_constant", + "scatter_reduce_amax_one_dim_indexTwo_constant_include_self_True", 1, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), @@ -1375,9 +1630,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "amax", + True, ), ( - "scatter_reduce_amax_one_dim_indexTwo_constant_3D", + "scatter_reduce_amax_one_dim_indexTwo_constant_3D_include_self_True", 1, torch.tensor( [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] @@ -1392,10 +1648,153 @@ def forward(self, input): }, torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), "amax", + True, + ), + # ######################### amax include_self= False############################# + ( + "scatter_reduce_amax_zero_dim_indexOne_constant_include_self_False", + 0, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amax", + False, + ), + ( + "scatter_reduce_amax_zero_dim_indexTwo_constant_include_self_False", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amax", + False, + ), + ( + "scatter_reduce_amax_one_dim_indexOne_constant_include_self_False", + 1, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amax", + False, + ), + ( + "scatter_reduce_amax_one_dim_indexTwo_constant_include_self_False", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amax", + False, + ), + ( + "scatter_reduce_amax_one_dim_indexTwo_constant_3D_include_self_False", + 1, + torch.tensor( + [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] + ).cuda(), + torch.tensor( + [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]], + dtype=torch.int32, + ).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), + "amax", + False, ), # #############################amin########################### ( - "scatter_reduce_amin_zero_dim_indexOne_constant", + "scatter_reduce_amin_zero_dim_indexOne_constant_include_self_True", + 0, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), + { + torch.ops.aten.minimum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amin", + True, + ), + ( + "scatter_reduce_amin_zero_dim_indexTwo_constant_include_self_True", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), + { + torch.ops.aten.minimum.default, + 
torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amin", + True, + ), + ( + "scatter_reduce_amin_one_dim_indexOne_constant_include_self_True", + 1, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), + { + torch.ops.aten.minimum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amin", + True, + ), + ( + "scatter_reduce_amin_one_dim_indexTwo_constant_include_self_True", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), + { + torch.ops.aten.minimum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amin", + True, + ), + ( + "scatter_reduce_amin_one_dim_indexTwo_constant_3D_include_self_True", + 1, + torch.tensor( + [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] + ).cuda(), + torch.tensor( + [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]], + dtype=torch.int32, + ).cuda(), + { + torch.ops.aten.minimum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), + "amin", + True, + ), + # ######################### amin include_self= False############################# + ( + "scatter_reduce_amin_zero_dim_indexOne_constant_include_self_False", 0, torch.tensor([[0, 1, 2, 0]]).cuda(), torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), @@ -1405,9 +1804,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "amin", + False, ), ( - "scatter_reduce_amin_zero_dim_indexTwo_constant", + "scatter_reduce_amin_zero_dim_indexTwo_constant_include_self_False", 0, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), @@ -1417,9 +1817,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "amin", + False, ), ( - "scatter_reduce_amin_one_dim_indexOne_constant", + "scatter_reduce_amin_one_dim_indexOne_constant_include_self_False", 1, torch.tensor([[0, 1, 2, 0]]).cuda(), torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), @@ -1429,9 +1830,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "amin", + False, ), ( - "scatter_reduce_amin_one_dim_indexTwo_constant", + "scatter_reduce_amin_one_dim_indexTwo_constant_include_self_False", 1, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), @@ -1441,9 +1843,10 @@ def forward(self, input): }, torch.zeros(3, 5, dtype=torch.int32).cuda(), "amin", + False, ), ( - "scatter_reduce_amin_one_dim_indexTwo_constant_3D", + "scatter_reduce_amin_one_dim_indexTwo_constant_3D_include_self_False", 1, torch.tensor( [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] @@ -1458,11 +1861,20 @@ def forward(self, input): }, torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), "amin", + False, ), ] ) def test_scatter_reduce( - self, _, dim, index, src, expected_ops_param, input_reduce_op, reduce_op_str + self, + _, + dim, + index, + src, + expected_ops_param, + input_reduce_op, + reduce_op_str, + include_self, ): class TestModule(torch.nn.Module): def __init__(self): @@ -1471,7 +1883,12 @@ def __init__(self): def forward(self, input): return torch.ops.aten.scatter_reduce_.two( - input, dim, index, src, reduce=reduce_op_str + input, + dim, + index, + src, + reduce=reduce_op_str, + include_self=include_self, ) # Operations expected to be included in the traced graph after 
decompositions
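For reference, the expected output of any row in the table can be derived in
eager mode with the same in-place op the test module calls. A hypothetical
helper (illustrative only, not part of this patch; `eager_reference` is an
invented name):

    import torch

    def eager_reference(dim, index, src, input_t, reduce_op, include_self):
        # Ground truth the decomposed graph is compared against; exact
        # equality is reasonable here because the test tensors are int32.
        return input_t.clone().scatter_reduce_(
            dim, index, src, reduce=reduce_op, include_self=include_self
        )

    # The "scatter_reduce_prod_one_include_self_False" case from the table:
    out = eager_reference(
        0,
        torch.tensor([0, 1, 0, 1]),
        torch.tensor([1, 2, 3, 4], dtype=torch.int32),
        torch.tensor([1, 2, 3, 4, 6, 7], dtype=torch.int32),
        "prod",
        False,
    )
    print(out)  # tensor([3, 8, 3, 4, 6, 7], dtype=torch.int32): 1*3, 2*4, rest unchanged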