Skip to content

Commit

Permalink
fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Dec 6, 2024
1 parent acf5ac0 commit a706b27
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import abc
import dataclasses
from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias
from typing import TYPE_CHECKING, Any, Final, Iterable, Optional, Protocol, Sequence, TypeAlias

import dace
from dace import subsets as sbs
Expand Down Expand Up @@ -219,17 +219,26 @@ def _parse_fieldop_arg(
) -> (
gtir_dataflow.IteratorExpr
| gtir_dataflow.MemletExpr
| tuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr, ...]
| gtir_dataflow.ValueExpr
| tuple[
gtir_dataflow.IteratorExpr
| gtir_dataflow.MemletExpr
| gtir_dataflow.ValueExpr
| tuple[Any, ...],
...,
]
):
"""Helper method to visit an expression passed as argument to a field operator."""

arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state)

def get_arg_value(arg: FieldopData) -> gtir_dataflow.MemletExpr | gtir_dataflow.IteratorExpr:
# In case of scan field operator, the arguments to the vertical stencil are passed by value.
def get_arg_value(
arg: FieldopData,
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr:
arg_expr = arg.get_local_view(domain)
if not by_value or isinstance(arg_expr, gtir_dataflow.MemletExpr):
return arg_expr
# In case of scan field operator, the arguments to the vertical stencil are passed by value.
return gtir_dataflow.MemletExpr(
arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg)
)
Expand Down Expand Up @@ -277,7 +286,8 @@ def _create_field_operator(
node_type: ts.FieldType | ts.TupleType,
sdfg_builder: gtir_sdfg.SDFGBuilder,
input_edges: Iterable[gtir_dataflow.DataflowInputEdge],
output_edges: gtir_dataflow.DataflowOutputEdge | tuple[gtir_dataflow.DataflowOutputEdge, ...],
output_edges: gtir_dataflow.DataflowOutputEdge
| tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...],
scan_dim: Optional[gtx_common.Dimension] = None,
) -> FieldopResult:
"""
Expand Down Expand Up @@ -446,16 +456,18 @@ def translate_as_fieldop(

if cpm.is_call_to(fieldop_expr, "scan"):
return translate_scan(node, sdfg, state, sdfg_builder)
elif isinstance(fieldop_expr, gtir.Lambda):
# Default case, handled below: the argument expression is a lambda function
# representing the stencil operation to be computed over the field domain.
stencil_expr = fieldop_expr
elif cpm.is_ref_to(fieldop_expr, "deref"):

assert isinstance(node.type, ts.FieldType)
if cpm.is_ref_to(fieldop_expr, "deref"):
# Special usage of 'deref' as argument to fieldop expression, to pass a scalar
# value to 'as_fieldop' function. It results in broadcasting the scalar value
# over the field domain.
stencil_expr = im.lambda_("a")(im.deref("a"))
stencil_expr.expr.type = node.type.dtype
elif isinstance(fieldop_expr, gtir.Lambda):
# Default case, handled below: the argument expression is a lambda function
# representing the stencil operation to be computed over the field domain.
stencil_expr = fieldop_expr
else:
raise NotImplementedError(
f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node."
Expand Down Expand Up @@ -835,6 +847,7 @@ def translate_scan(
) -> FieldopResult:
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")
assert isinstance(node.type, (ts.FieldType, ts.TupleType))

fun_node = node.fun
assert len(fun_node.args) == 2
Expand Down Expand Up @@ -886,7 +899,9 @@ def scan_output_name(input_name: str) -> str:

# create list of params to the lambda function with associated node type
lambda_symbols = {scan_state: scan_state_type} | {
str(p.id): arg.type for p, arg in zip(stencil_expr.params[1:], node.args, strict=True)
str(p.id): arg.type
for p, arg in zip(stencil_expr.params[1:], node.args, strict=True)
if isinstance(arg.type, ts.DataType)
}

# visit the arguments to be passed to the lambda expression
Expand All @@ -900,8 +915,8 @@ def scan_output_name(input_name: str) -> str:
}

# parse the dataflow input and output symbols
lambda_flat_args = {}
lambda_field_offsets = {}
lambda_flat_args: dict[str, FieldopData] = {}
lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {}
for param, arg in lambda_args_mapping.items():
tuple_fields = flatten_tuples(param, arg)
lambda_field_offsets |= {tsym: tfield.offset for tsym, tfield in tuple_fields}
Expand Down Expand Up @@ -986,9 +1001,10 @@ def init_scan_state(sym: gtir.Sym) -> None:
nsdfg.make_array_memlet(input_state),
)

init_scan_state(scan_state_input) if isinstance(
scan_state_input, FieldopData
) else gtx_utils.tree_map(init_scan_state)(scan_state_input)
if isinstance(scan_state_input, tuple):
gtx_utils.tree_map(init_scan_state)(scan_state_input)
else:
init_scan_state(scan_state_input)

# connect the dataflow input directly to the source data nodes, without passing through a map node;
# the reason is that the map for horizontal domain is outside the scan loop region
Expand All @@ -1000,8 +1016,8 @@ def connect_scan_output(
scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym
) -> FieldopData:
scan_result = scan_output_edge.result
assert isinstance(scan_result, gtir_dataflow.ValueExpr)
assert isinstance(sym.type, ts.ScalarType) and scan_result.gt_dtype == sym.type
assert isinstance(scan_result.gt_dtype, ts.ScalarType)
assert scan_result.gt_dtype == sym.type
scan_result_data = scan_result.dc_node.data
scan_result_desc = scan_result.dc_node.desc(nsdfg)

Expand All @@ -1023,12 +1039,17 @@ def connect_scan_output(
output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype)
return FieldopData(output_node, output_type, scan_output_offset)

if isinstance(scan_state_input, gtir.Sym):
assert isinstance(result, gtir_dataflow.DataflowOutputEdge)
lambda_output = connect_scan_output(result, scan_state_input)
else:
assert isinstance(result, tuple)
lambda_output = gtx_utils.tree_map(connect_scan_output)(result, scan_state_input)
lambda_output = (
gtx_utils.tree_map(connect_scan_output)(result, scan_state_input)
if (isinstance(result, tuple) and isinstance(scan_state_input, tuple))
else connect_scan_output(result, scan_state_input)
if (
isinstance(result, gtir_dataflow.DataflowOutputEdge)
and isinstance(scan_state_input, gtir.Sym)
)
else None
)
assert lambda_output

# in case tuples are passed as argument, isolated non-transient nodes might be left in the state,
# because not all tuple fields are necessarily accessed in the lambda scope
Expand Down Expand Up @@ -1095,15 +1116,14 @@ def construct_output_edge(scan_data: FieldopData) -> gtir_dataflow.DataflowOutpu
None,
dace.Memlet.from_array(output_data, output_desc),
)
output_expr = gtir_dataflow.MemletExpr(
output_node, scan_data.gt_type.dtype, sbs.Range.from_array(output_desc)
)
output_expr = gtir_dataflow.ValueExpr(output_node, scan_data.gt_type.dtype)
return gtir_dataflow.DataflowOutputEdge(state, output_expr)

if isinstance(lambda_output, FieldopData):
output_edges = construct_output_edge(lambda_output)
else:
output_edges = gtx_utils.tree_map(construct_output_edge)(lambda_output)
output_edges = (
construct_output_edge(lambda_output)
if isinstance(lambda_output, FieldopData)
else gtx_utils.tree_map(construct_output_edge)(lambda_output)
)

return _create_field_operator(
sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges, scan_dim
Expand Down
Loading

0 comments on commit a706b27

Please sign in to comment.