Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scatter reduce lowering with include_self=False #3153

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 81 additions & 13 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,21 +303,30 @@ def __new__(cls, description: Any, func: Any) -> Any:
obj.func = func
return obj

def reduce_operation_with_scatter(
def reduce_operation_with_scatter_include_self(
self,
operation_lhs: Any,
initial_tensor: torch.Tensor,
dim: int,
index_tensor: torch.Tensor,
src_tensor: torch.Tensor,
) -> Any:
operation_lhs,
initial_tensor,
dim,
index_tensor,
src_tensor,
min_ele=float("-inf"),
max_ele=float("inf"),
include_self=True,
):
scatter_tensor = None
if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
scatter_tensor = torch.zeros_like(initial_tensor)
elif self == ReduceOperation.PROD:
scatter_tensor = torch.ones_like(initial_tensor)
elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
elif self == ReduceOperation.AMAX:
scatter_tensor = initial_tensor
if not (include_self):
scatter_tensor = torch.full_like(initial_tensor, min_ele)
elif self == ReduceOperation.AMIN:
scatter_tensor = initial_tensor
if not (include_self):
scatter_tensor = torch.full_like(initial_tensor, max_ele)
else:
# This case would not be encountered from torch itself
print("Invalid Operation for Reduce op!!")
Expand All @@ -341,13 +350,39 @@ def scatter_reduce_decomposition(
include_self: bool = True,
) -> torch.Tensor:
scatter_loop_tensor = input_tensor
MAX_ELE = 0
MIN_ELE = 0
if src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32:
MAX_ELE = 2147483647
MIN_ELE = -2147483648
else:
MAX_ELE = float("inf")
MIN_ELE = float("-inf")
if not (include_self):
if reduce == "sum" or reduce == "mean":
scatter_loop_tensor = torch.scatter(
scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor)
)
if reduce == "prod":
scatter_loop_tensor = torch.scatter(
scatter_loop_tensor, dim, index, torch.ones_like(src_tensor)
)
if reduce == "amax":
src_red_tensor = torch.full_like(src_tensor, MIN_ELE)
scatter_loop_tensor = torch.scatter(
scatter_loop_tensor, dim, index, src_red_tensor
)
if reduce == "amin":
src_red_tensor = torch.full_like(src_tensor, MAX_ELE)
scatter_loop_tensor = torch.scatter(
scatter_loop_tensor, dim, index, src_red_tensor
)

device_input_tensor = input_tensor.device
# required for mean reduce operation
scatter_count_tensor = torch.zeros_like(input_tensor)
src_shape = list(src_tensor.shape)
src_dim = src_shape[dim]
if not include_self:
raise AssertionError("include_self False for scatter reduce not yet supported")
for i in range(0, src_dim):
src_slice = torch.select(src_tensor, dim, i)
index_slice = torch.select(index, dim, i)
Expand All @@ -371,20 +406,53 @@ def scatter_reduce_decomposition(
dim,
index_slice,
torch.ones_like(src_slice),
MIN_ELE,
MAX_ELE,
include_self,
)
elif reduce == "amax":
reduceOp = ReduceOperation.AMAX
elif reduce == "amin":
reduceOp = ReduceOperation.AMIN
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter_include_self(
scatter_loop_tensor,
input_tensor,
dim,
index_slice,
src_slice,
MIN_ELE,
MAX_ELE,
include_self,
)
if reduce == "mean":
scatter_loop_tensor = torch.div(
scatter_loop_tensor,
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
(
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor))
if include_self
else scatter_count_tensor
),
rounding_mode="trunc",
)
# for include_self cases for amax and amin additional processing is required
# except for the max elements in amax, rest are -inf or INT_MIN
# except for the min elements in amin, rest are +inf or INT_MAX
if reduce == "amax" and not (include_self):
# the relevant should be min, rest original
return torch.max(
scatter_loop_tensor,
torch.scatter(
input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)
),
)
if reduce == "amin" and not (include_self):
# the relevant should be min, rest original
return torch.min(
scatter_loop_tensor,
torch.scatter(
input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)
),
)
return scatter_loop_tensor


Expand Down
Loading
Loading