Skip to content

Commit

Permalink
feat[next]: Add support for IfStmt in ITIR (#1664)
Browse files Browse the repository at this point in the history
Adds support for stateful if statements in ITIR for ITIR embedded and
GTFN backend.
  • Loading branch information
tehrengruber authored Oct 3, 2024
1 parent 4ecc69e commit 3b6261f
Show file tree
Hide file tree
Showing 16 changed files with 284 additions and 21 deletions.
20 changes: 20 additions & 0 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,26 @@ def set_at(expr: common.Field, domain: common.DomainLike, target: common.Mutable
operators._tuple_assign_field(target, expr, common.domain(domain))


@runtime.if_stmt.register(EMBEDDED)
def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[[], None]) -> None:
"""
(Stateful) if statement.
The two branches are represented as lambda functions, such that they are not executed eagerly.
This is required to avoid out-of-bounds accesses. Note that a dedicated built-in is required,
contrary to using a plain python if-stmt, such that tracing / double roundtrip works.
Arguments:
cond: The condition to decide which branch to execute.
true_branch: A lambda function to be executed when `cond` is `True`.
false_branch: A lambda function to be executed when `cond` is `False`.
"""
if cond:
true_branch()
else:
false_branch()


def _compute_at_position(
sten: Callable,
ins: Sequence[common.Field],
Expand Down
7 changes: 7 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ class SetAt(Stmt): # from JAX array.at[...].set()
target: Expr # `make_tuple` or SymRef


class IfStmt(Stmt):
cond: Expr
true_branch: list[Stmt]
false_branch: list[Stmt]


class Temporary(Node):
id: Coerced[eve.SymbolName]
domain: Optional[Expr] = None
Expand Down Expand Up @@ -242,3 +248,4 @@ class Program(Node, ValidatedSymbolTableTrait):
FencilDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign]
Program.__hash__ = Node.__hash__ # type: ignore[method-assign]
SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign]
IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign]
34 changes: 30 additions & 4 deletions src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing import Union

from lark import lark, lexer as lark_lexer, visitors as lark_visitors
from lark import lark, lexer as lark_lexer, tree as lark_tree, visitors as lark_visitors

from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im
Expand All @@ -21,6 +21,7 @@
| declaration
| stencil_closure
| set_at
| if_stmt
| program
| prec0
Expand Down Expand Up @@ -78,13 +79,17 @@
| named_range
| "(" prec0 ")"
?stmt: set_at | if_stmt
set_at: prec0 "@" prec0 "←" prec1 ";"
else_branch_seperator: "else"
if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}"
named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")"
function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";"
stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";"
set_at: prec0 "@" prec0 "←" prec1 ";"
fencil_definition: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( stencil_closure )+ "}"
program: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( declaration )* ( set_at )+ "}"
program: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( declaration )* ( stmt )+ "}"
%import common (CNAME, SIGNED_FLOAT, SIGNED_INT, WS)
%ignore WS
Expand Down Expand Up @@ -215,6 +220,27 @@ def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure:
output, stencil, *inputs, domain = args
return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs)

def if_stmt(self, cond: ir.Expr, *args):
found_else_seperator = False
true_branch = []
false_branch = []
for arg in args:
if isinstance(arg, lark_tree.Tree):
assert arg.data == "else_branch_seperator"
found_else_seperator = True
continue

if not found_else_seperator:
true_branch.append(arg)
else:
false_branch.append(arg)

return ir.IfStmt(
cond=cond,
true_branch=true_branch,
false_branch=false_branch,
)

def declaration(self, *args: ir.Expr) -> ir.Temporary:
tid, domain, dtype = args
return ir.Temporary(id=tid, domain=domain, dtype=dtype)
Expand Down Expand Up @@ -253,7 +279,7 @@ def program(self, fid: str, *args: ir.Node) -> ir.Program:
elif isinstance(arg, ir.Temporary):
declarations.append(arg)
else:
assert isinstance(arg, ir.SetAt)
assert isinstance(arg, ir.Stmt)
body.append(arg)
return ir.Program(
id=fid,
Expand Down
13 changes: 13 additions & 0 deletions src/gt4py/next/iterator/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,19 @@ def visit_SetAt(self, node: ir.SetAt, *, prec: int) -> list[str]:
)
return self._optimum(h, v)

def visit_IfStmt(self, node: ir.IfStmt, *, prec: int) -> list[str]:
cond = self.visit(node.cond, prec=0)
true_branch = self._vmerge(*self.visit(node.true_branch, prec=0))
false_branch = self._vmerge(*self.visit(node.false_branch, prec=0))

hhead = self._hmerge(["if ("], cond, [") {"])
vhead = self._vmerge(["if ("], cond, [") {"])
head = self._optimum(hhead, vhead)

return self._vmerge(
head, self._indent(true_branch), ["} else {"], self._indent(false_branch), ["}"]
)

def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]:
assert prec == 0
function_definitions = self.visit(node.function_definitions, prec=0)
Expand Down
7 changes: 6 additions & 1 deletion src/gt4py/next/iterator/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# TODO(tehrengruber): remove cirular dependency and import unconditionally
from gt4py.next import backend as next_backend

__all__ = ["offset", "fundef", "fendef", "closure", "set_at"]
__all__ = ["offset", "fundef", "fendef", "closure", "set_at", "if_stmt"]


@dataclass(frozen=True)
Expand Down Expand Up @@ -214,3 +214,8 @@ def closure(*args): # TODO remove
@builtin_dispatch
def set_at(*args):
return BackendNotSelectedError()


@builtin_dispatch
def if_stmt(*args):
return BackendNotSelectedError()
21 changes: 21 additions & 0 deletions src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,27 @@ def set_at(expr: itir.Expr, domain: itir.Expr, target: itir.Expr) -> None:
TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target))


@iterator.runtime.if_stmt.register(TRACING)
def if_stmt(
cond: itir.Expr, true_branch_f: typing.Callable, false_branch_f: typing.Callable
) -> None:
true_branch: List[itir.Stmt] = []
false_branch: List[itir.Stmt] = []

old_body = TracerContext.body
TracerContext.body = true_branch
true_branch_f()

TracerContext.body = false_branch
false_branch_f()

TracerContext.body = old_body

TracerContext.add_stmt(
itir.IfStmt(cond=cond, true_branch=true_branch, false_branch=false_branch)
)


def _contains_tuple_dtype_field(arg):
if isinstance(arg, tuple):
return any(_contains_tuple_dtype_field(el) for el in arg)
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,12 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype
)

def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None:
cond = self.visit(node.cond, ctx=ctx)
assert cond == ts.ScalarType(kind=ts.ScalarKind.BOOL)
self.visit(node.true_branch, ctx=ctx)
self.visit(node.false_branch, ctx=ctx)

def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None:
self.visit(node.expr, ctx=ctx)
self.visit(node.domain, ctx=ctx)
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/program_processors/codegens/gtfn/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,16 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> str:
"{backend}.vertical_executor({axis})().{'.'.join('arg(' + a + ')' for a in args)}.{'.'.join(scans)}.execute();"
)

IfStmt = as_mako(
"""
if (${cond}) {
${'\\n'.join(true_branch)}
} else {
${'\\n'.join(false_branch)}
}
"""
)

ScanPassDefinition = as_mako(
"""
struct ${id} : ${'gtfn::fwd' if _this_node.forward else 'gtfn::bwd'} {
Expand Down
29 changes: 20 additions & 9 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,23 @@ class Backend(Node):
domain: Union[SymRef, CartesianDomain, UnstructuredDomain]


def _is_ref_or_tuple_expr_of_ref(expr: Expr) -> bool:
def _is_ref_literal_or_tuple_expr_of_ref(expr: Expr) -> bool:
if (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
and expr.fun.id == "tuple_get"
and len(expr.args) == 2
and _is_ref_or_tuple_expr_of_ref(expr.args[1])
and _is_ref_literal_or_tuple_expr_of_ref(expr.args[1])
):
return True
if (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
and expr.fun.id == "make_tuple"
and all(_is_ref_or_tuple_expr_of_ref(arg) for arg in expr.args)
and all(_is_ref_literal_or_tuple_expr_of_ref(arg) for arg in expr.args)
):
return True
if isinstance(expr, SymRef):
if isinstance(expr, (SymRef, Literal)):
return True
return False

Expand All @@ -125,7 +125,8 @@ def _values_validator(
self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: list[Expr]
) -> None:
if not all(
isinstance(el, (SidFromScalar, SidComposite)) or _is_ref_or_tuple_expr_of_ref(el)
isinstance(el, (SidFromScalar, SidComposite))
or _is_ref_literal_or_tuple_expr_of_ref(el)
for el in value
):
raise ValueError(
Expand All @@ -140,11 +141,15 @@ class SidFromScalar(Expr):
def _arg_validator(
self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: Expr
) -> None:
if not _is_ref_or_tuple_expr_of_ref(value):
if not _is_ref_literal_or_tuple_expr_of_ref(value):
raise ValueError("Only 'SymRef' or tuple expr of 'SymRef' allowed.")


class StencilExecution(Node):
class Stmt(Node):
pass


class StencilExecution(Stmt):
backend: Backend
stencil: SymRef
output: Union[SymRef, SidComposite]
Expand All @@ -158,13 +163,19 @@ class Scan(Node):
init: Expr


class ScanExecution(Node):
class ScanExecution(Stmt):
backend: Backend
scans: list[Scan]
args: list[Expr]
axis: SymRef


class IfStmt(Stmt):
cond: Expr
true_branch: list[Stmt]
false_branch: list[Stmt]


class TemporaryAllocation(Node):
id: SymbolName
dtype: str
Expand Down Expand Up @@ -199,7 +210,7 @@ class Program(Node, ValidatedSymbolTableTrait):
function_definitions: list[
Union[FunctionDefinition, ScanPassDefinition, ImperativeFunctionDefinition]
]
executions: list[Union[StencilExecution, ScanExecution]]
executions: list[Stmt]
offset_definitions: list[TagDefinition]
grid_type: common.GridType
temporaries: list[TemporaryAllocation]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSetting

def _process_regular_arguments(
self,
program: itir.FencilDefinition,
program: itir.FencilDefinition | itir.Program,
arg_types: tuple[ts.TypeSpec, ...],
offset_provider: dict[str, Connectivity | Dimension],
) -> tuple[list[interface.Parameter], list[str]]:
Expand Down Expand Up @@ -161,10 +161,10 @@ def _process_connectivity_args(

def _preprocess_program(
self,
program: itir.FencilDefinition,
program: itir.FencilDefinition | itir.Program,
offset_provider: dict[str, Connectivity | Dimension],
) -> itir.Program:
if not self.enable_itir_transforms:
if isinstance(program, itir.FencilDefinition) and not self.enable_itir_transforms:
return fencil_to_program.FencilToProgram().apply(
program
) # FIXME[#1582](tehrengruber): should be removed after refactoring to combined IR
Expand Down Expand Up @@ -195,7 +195,7 @@ def _preprocess_program(

def generate_stencil_source(
self,
program: itir.FencilDefinition,
program: itir.FencilDefinition | itir.Program,
offset_provider: dict[str, Connectivity | Dimension],
column_axis: Optional[common.Dimension],
) -> str:
Expand All @@ -216,7 +216,6 @@ def __call__(
) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]:
"""Generate GTFN C++ code from the ITIR definition."""
program: itir.FencilDefinition | itir.Program = inp.data
assert isinstance(program, itir.FencilDefinition)

# handle regular parameters and arguments of the program (i.e. what the user defined in
# the program)
Expand Down
15 changes: 13 additions & 2 deletions src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
CastExpr,
FunCall,
FunctionDefinition,
IfStmt,
IntegralConstant,
Lambda,
Literal,
Expand Down Expand Up @@ -66,8 +67,11 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]:
_horizontal_dimension = "gtfn::unstructured::dim::horizontal"


def _get_domains(node: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]:
return eve_utils.xiter(node).if_isinstance(itir.SetAt).getattr("domain").to_set()
def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]:
result = set()
for node in nodes:
result |= node.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set()
return result


def _extract_grid_type(domain: itir.FunCall) -> common.GridType:
Expand Down Expand Up @@ -573,6 +577,13 @@ def remap_args(s: Scan) -> Scan:
def visit_Stmt(self, node: itir.Stmt, **kwargs: Any) -> None:
raise AssertionError("All Stmts need to be handled explicitly.")

def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt:
return IfStmt(
cond=self.visit(node.cond, **kwargs),
true_branch=self.visit(node.true_branch, **kwargs),
false_branch=self.visit(node.false_branch, **kwargs),
)

def visit_SetAt(
self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any
) -> Union[StencilExecution, ScanExecution]:
Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/next/program_processors/runners/roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def ${id}(${','.join(params)}):
"""
)
SetAt = as_mako("set_at(${expr}, ${domain}, ${target})")
IfStmt = as_mako("""if_stmt(${cond},
lambda: [${','.join(true_branch)}],
lambda: [${','.join(false_branch)}]
)""")

def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str:
assert (
Expand Down
Loading

0 comments on commit 3b6261f

Please sign in to comment.