Skip to content

Commit

Permalink
[dace] Cleanup shift visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Oct 25, 2023
1 parent b20ec2c commit 289cd75
Showing 1 changed file with 30 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -478,17 +478,7 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr:
return self._visit_deref(node)
if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef):
if node.fun.fun.id == "shift":
offset = node.fun.args[0]
assert isinstance(offset, itir.OffsetLiteral)
offset_name = offset.value
assert isinstance(offset_name, str)
if offset_name not in self.offset_provider:
raise ValueError(f"offset provider for `{offset_name}` is missing")
offset_provider = self.offset_provider[offset_name]
if isinstance(offset_provider, Dimension):
return self._visit_direct_addressing(node)
else:
return self._visit_indirect_addressing(node)
return self._visit_shift(node)
elif node.fun.fun.id == "reduce":
return self._visit_reduce(node)

Expand Down Expand Up @@ -653,39 +643,7 @@ def _make_shift_for_rest(self, rest, iterator):
fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator]
)

def _visit_direct_addressing(self, node: itir.FunCall) -> IteratorExpr:
assert isinstance(node.fun, itir.FunCall)
shift = node.fun
assert isinstance(shift, itir.FunCall)

tail, rest = self._split_shift_args(shift.args)
if rest:
iterator = self.visit(self._make_shift_for_rest(rest, node.args[0]))
else:
iterator = self.visit(node.args[0])

assert isinstance(tail[0], itir.OffsetLiteral)
offset = tail[0].value
assert isinstance(offset, str)
shifted_dim = self.offset_provider[offset].value

assert isinstance(tail[1], itir.OffsetLiteral)
shift_amount = tail[1].value
assert isinstance(shift_amount, int)

args = [ValueExpr(iterator.indices[shifted_dim], dace.int64)]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]} + {shift_amount}"
shifted_value = self.add_expr_tasklet(
list(zip(args, internals)), expr, dace.dtypes.int64, "dir_addr"
)[0].value

shifted_index = {dim: value for dim, value in iterator.indices.items()}
shifted_index[shifted_dim] = shifted_value

return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions)

def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr:
def _visit_shift(self, node: itir.FunCall) -> IteratorExpr:
shift = node.fun
assert isinstance(shift, itir.FunCall)
tail, rest = self._split_shift_args(shift.args)
Expand All @@ -699,53 +657,42 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr:
assert isinstance(offset, str)

if isinstance(tail[1], itir.OffsetLiteral):
element = tail[1].value
assert isinstance(element, int)
element_var = unique_var_name()
self.context.body.add_scalar(element_var, dace.dtypes.int64, transient=True)
element_node = self.context.state.add_access(element_var)
tlet_node = self.context.state.add_tasklet(
"get_element", {}, {"__out"}, f"__out = {element}"
)
self.context.state.add_edge(
tlet_node, "__out", element_node, None, dace.Memlet.simple(element_var, "0")
)
offset_node = self.visit_OffsetLiteral(tail[1])
else:
element_node = self.visit(tail[1])[0]
offset_node = self.visit(tail[1])[0]

if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider):
table = self.offset_provider[offset]
shifted_dim = table.origin_axis.value
target_dim = table.neighbor_axis.value

conn = self.context.state.add_access(connectivity_identifier(offset))
offset_provider = self.offset_provider[offset]
connectivity = self.context.state.add_access(connectivity_identifier(offset))

shifted_dim = offset_provider.origin_axis.value
target_dim = offset_provider.neighbor_axis.value
args = [
ValueExpr(conn, table.table.dtype),
ValueExpr(connectivity, offset_provider.table.dtype),
ValueExpr(iterator.indices[shifted_dim], dace.int64),
ValueExpr(element_node, dace.int64),
offset_node,
]

internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]}[{internals[1]}, {internals[2]}]"
elif isinstance(self.offset_provider[offset], StridedNeighborOffsetProvider):
offset_provider = self.offset_provider[offset]

shifted_dim = offset_provider.origin_axis.value
target_dim = offset_provider.neighbor_axis.value
offset_value = iterator.indices[shifted_dim]
args = [
ValueExpr(offset_value, dace.int64),
ValueExpr(element_node, dace.int64),
ValueExpr(iterator.indices[shifted_dim], dace.int64),
offset_node,
]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]} * {offset_provider.max_neighbors} + {internals[1]}"
else:
assert isinstance(self.offset_provider[offset], Dimension)
shifted_dim = target_dim = self.offset_provider[offset].value

shifted_dim = self.offset_provider[offset].value
target_dim = shifted_dim
args = [
ValueExpr(iterator.indices[shifted_dim], dace.int64),
element_node,
offset_node,
]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]} + {internals[1]}"
Expand All @@ -760,6 +707,20 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr:

return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions)

def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> ValueExpr:
offset = node.value
assert isinstance(offset, int)
offset_var = unique_var_name()
self.context.body.add_scalar(offset_var, dace.dtypes.int64, transient=True)
offset_node = self.context.state.add_access(offset_var)
tasklet_node = self.context.state.add_tasklet(
"get_offset", {}, {"__out"}, f"__out = {offset}"
)
self.context.state.add_edge(
tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0")
)
return ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)

def _visit_reduce(self, node: itir.FunCall):
result_name = unique_var_name()
result_access = self.context.state.add_access(result_name)
Expand Down

0 comments on commit 289cd75

Please sign in to comment.