Skip to content

Commit

Permalink
feat[next]: GTIR temporary extraction pass (#1678)
Browse files Browse the repository at this point in the history
New temporary extraction pass. Transforms an `itir.Program` like
```
testee(inp, out) {
  out @ c⟨ IDimₕ: [0, 1) ⟩
       ← as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(inp));
}
```
into
```
testee(inp, out) {
  __tmp_1 = temporary(domain=c⟨ IDimₕ: [0, 1) ⟩, dtype=float64);
  __tmp_1 @ c⟨ IDimₕ: [0, 1) ⟩ ← as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(inp);
  out @ c⟨ IDimₕ: [0, 1) ⟩ ← as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(__tmp_1);
}
```
Note that this pass intentionally unconditionally extracts. In case you
don't want a temporary you should fuse the `as_fieldop` before. As such
the fusion pass (see #1670)
contains the heuristics on what to fuse.
  • Loading branch information
tehrengruber authored Oct 18, 2024
1 parent cb77ccb commit eb0a0c1
Show file tree
Hide file tree
Showing 12 changed files with 476 additions and 1,093 deletions.
13 changes: 7 additions & 6 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

@dataclasses.dataclass
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)
PRESERVED_ANNEX_ATTRS = ("type", "domain")

expr_map: dict[int, itir.SymRef]

Expand All @@ -43,15 +43,16 @@ def visit_Expr(self, node: itir.Node) -> itir.Node:

def visit_FunCall(self, node: itir.FunCall) -> itir.Node:
node = cast(itir.FunCall, self.visit_Expr(node))
# TODO(tehrengruber): Use symbol name from the inner let, to increase readability of IR
# If we encounter an expression like:
# (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr)
# (non-recursively) inline the lambda to obtain:
# (λ(_cs_1) → _cs_1+_cs_1)(outer_expr)
# This allows identifying more common subexpressions later on
# In the CSE this allows identifying more common subexpressions later on. Other users
# of `extract_subexpression` (e.g. temporary extraction) can also rely on this to avoid
# the need to handle this artificial let-statements.
if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda):
eligible_params = []
for arg in node.args:
eligible_params.append(isinstance(arg, itir.SymRef) and arg.id.startswith("_cs"))
eligible_params = [isinstance(arg, itir.SymRef) for arg in node.args]
if any(eligible_params):
# note: the inline is opcount preserving anyway so avoid the additional
# effort in the inliner by disabling opcount preservation.
Expand Down Expand Up @@ -319,7 +320,7 @@ def extract_subexpression(
subexprs = CollectSubexpressions.apply(node)

# collect multiple occurrences and map them to fresh symbols
expr_map = dict[int, itir.SymRef]()
expr_map: dict[int, itir.SymRef] = {}
ignored_ids = set()
for expr, subexpr_entry in (
subexprs.items() if not deepest_expr_first else reversed(subexprs.items())
Expand Down
15 changes: 1 addition & 14 deletions src/gt4py/next/iterator/transforms/fencil_to_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@
from gt4py import eve
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import global_tmps


class FencilToProgram(eve.NodeTranslator):
@classmethod
def apply(
cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries | itir.Program
) -> itir.Program:
def apply(cls, node: itir.FencilDefinition | itir.Program) -> itir.Program:
return cls().visit(node)

def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt:
Expand All @@ -32,13 +29,3 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program:
body=self.visit(node.closures),
implicit_domain=node.implicit_domain,
)

def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) -> itir.Program:
return itir.Program(
id=node.fencil.id,
function_definitions=node.fencil.function_definitions,
params=node.params,
declarations=node.tmps,
body=self.visit(node.fencil.closures),
implicit_domain=node.fencil.implicit_domain,
)
Loading

0 comments on commit eb0a0c1

Please sign in to comment.