Skip to content

Commit

Permalink
bug[next]: ConstantFolding after create_global_tmps (#1756)
Browse files Browse the repository at this point in the history
Do `ConstantFolding` within `domain_union` to avoid  nested minima and maxima by `create_global_tmps`

---------

Co-authored-by: Till Ehrengruber <[email protected]>
  • Loading branch information
SF-N and tehrengruber authored Nov 30, 2024
1 parent 04513ba commit d581060
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.transforms.constant_folding import ConstantFolding


def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]:
Expand Down Expand Up @@ -168,6 +169,8 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain:
lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr),
[domain.ranges[dim].stop for domain in domains],
)
# constant fold expression to keep the tree small
start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr
new_domain_ranges[dim] = SymbolicRange(start, stop)

return SymbolicDomain(domains[0].grid_type, new_domain_ranges)

0 comments on commit d581060

Please sign in to comment.