Skip to content

Commit

Permalink
scatter reduce lowering with include_self=False
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Sep 11, 2024
1 parent 501a1e1 commit 47fec01
Show file tree
Hide file tree
Showing 2 changed files with 529 additions and 39 deletions.
89 changes: 81 additions & 8 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,16 +303,30 @@ def __new__(cls, description, func):
obj.func = func
return obj

def reduce_operation_with_scatter(
self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
def reduce_operation_with_scatter_include_self(
self,
operation_lhs,
initial_tensor,
dim,
index_tensor,
src_tensor,
min_ele=float("-inf"),
max_ele=float("inf"),
include_self=True,
):
scatter_tensor = None
if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
scatter_tensor = torch.zeros_like(initial_tensor)
elif self == ReduceOperation.PROD:
scatter_tensor = torch.ones_like(initial_tensor)
elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
elif self == ReduceOperation.AMAX:
scatter_tensor = initial_tensor
if not (include_self):
scatter_tensor = torch.full_like(initial_tensor, min_ele)
elif self == ReduceOperation.AMIN:
scatter_tensor = initial_tensor
if not (include_self):
scatter_tensor = torch.full_like(initial_tensor, max_ele)
else:
# This case would not be encountered from torch itself
print("Invalid Operation for Reduce op!!")
Expand All @@ -336,13 +350,39 @@ def scatter_reduce_decomposition(
include_self: bool = True,
) -> torch.Tensor:
scatter_loop_tensor = input_tensor
MAX_ELE = 0
MIN_ELE = 0
if src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32:
MAX_ELE = 2147483647
MIN_ELE = -2147483648
else:
MAX_ELE = float("inf")
MIN_ELE = float("-inf")
if not (include_self):
if reduce == "sum" or reduce == "mean":
scatter_loop_tensor = torch.scatter(
scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor)
)
if reduce == "prod":
scatter_loop_tensor = torch.scatter(
scatter_loop_tensor, dim, index, torch.ones_like(src_tensor)
)
if reduce == "amax":
src_red_tensor = torch.full_like(src_tensor, MIN_ELE)
scatter_loop_tensor = torch.scatter(
scatter_loop_tensor, dim, index, src_red_tensor
)
if reduce == "amin":
src_red_tensor = torch.full_like(src_tensor, MAX_ELE)
scatter_loop_tensor = torch.scatter(
scatter_loop_tensor, dim, index, src_red_tensor
)

device_input_tensor = input_tensor.device
# required for mean reduce operation
scatter_count_tensor = torch.zeros_like(input_tensor)
src_shape = list(src_tensor.shape)
src_dim = src_shape[dim]
if include_self == False:
raise AssertionError("include_self False for scatter reduce not yet supported")
for i in range(0, src_dim):
src_slice = torch.select(src_tensor, dim, i)
index_slice = torch.select(index, dim, i)
Expand All @@ -366,20 +406,53 @@ def scatter_reduce_decomposition(
dim,
index_slice,
torch.ones_like(src_slice),
MIN_ELE,
MAX_ELE,
include_self,
)
elif reduce == "amax":
reduceOp = ReduceOperation.AMAX
elif reduce == "amin":
reduceOp = ReduceOperation.AMIN
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter_include_self(
scatter_loop_tensor,
input_tensor,
dim,
index_slice,
src_slice,
MIN_ELE,
MAX_ELE,
include_self,
)
if reduce == "mean":
scatter_loop_tensor = torch.div(
scatter_loop_tensor,
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
(
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor))
if include_self
else scatter_count_tensor
),
rounding_mode="trunc",
)
# for include_self cases for amax and amin additional processing is required
# except for the max elements in amax, rest are -inf or INT_MIN
# except for the min elements in amin, rest are +inf or INT_MAX
if reduce == "amax" and not (include_self):
# the relevant should be min, rest original
return torch.max(
scatter_loop_tensor,
torch.scatter(
input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)
),
)
if reduce == "amin" and not (include_self):
# the relevant should be min, rest original
return torch.min(
scatter_loop_tensor,
torch.scatter(
input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)
),
)
return scatter_loop_tensor


Expand Down
Loading

0 comments on commit 47fec01

Please sign in to comment.