From 716fd8d02903c192efdf4d4618b0056de7d1ab54 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 28 Nov 2024 16:30:42 +0100 Subject: [PATCH 1/5] Call ConstantFolding after create_global_tmps --- src/gt4py/next/iterator/transforms/pass_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ec6f89685a..9ba90ff10a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -120,6 +120,8 @@ def apply_common_transforms( if extract_temporaries: ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program + # create_global_tmps calls domain inference which creates domain unions, i.e., minima and maxima + ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can From 0bca033ba57ab97811b1e24b51cbda92fbc325ab Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 28 Nov 2024 16:58:27 +0100 Subject: [PATCH 2/5] Move ConstantFolding into global_tmps --- src/gt4py/next/iterator/transforms/global_tmps.py | 15 +++++++++------ .../next/iterator/transforms/pass_manager.py | 2 -- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index a6d39883e3..feca153e31 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -20,6 +20,7 @@ ir_makers as im, ) from gt4py.next.iterator.transforms import cse, infer_domain, inline_lambdas +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -200,10 +201,12 @@ def create_global_tmps( assert isinstance(stmt, itir.SetAt) new_body.extend(_transform_stmt(stmt, uids=uids, declarations=declarations)) - return itir.Program( - id=program.id, - function_definitions=program.function_definitions, - params=program.params, - declarations=declarations, - body=new_body, + return ConstantFolding.apply( # type: ignore[return-value] # returns same type as input + itir.Program( + id=program.id, + function_definitions=program.function_definitions, + params=program.params, + declarations=declarations, + body=new_body, + ) ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 9ba90ff10a..ec6f89685a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -120,8 +120,6 @@ def apply_common_transforms( if extract_temporaries: ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program - # create_global_tmps calls domain inference which creates domain unions, i.e., minima and maxima - ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can From 55f9cdb5e01a9f8b68192f763886750e2bc5ea2d Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 28 Nov 2024 18:04:58 +0100 Subject: [PATCH 3/5] Move ConstantFolding into infer_domain --- .../next/iterator/transforms/global_tmps.py | 15 ++++++--------- .../next/iterator/transforms/infer_domain.py | 17 +++++++++++------ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index feca153e31..a6d39883e3 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -20,7 +20,6 @@ ir_makers as im, ) from gt4py.next.iterator.transforms import cse, infer_domain, inline_lambdas -from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -201,12 +200,10 @@ def create_global_tmps( assert isinstance(stmt, itir.SetAt) new_body.extend(_transform_stmt(stmt, uids=uids, declarations=declarations)) - return ConstantFolding.apply( # type: ignore[return-value] # returns same type as input - itir.Program( - id=program.id, - function_definitions=program.function_definitions, - params=program.params, - declarations=declarations, - body=new_body, - ) + return itir.Program( + id=program.id, + function_definitions=program.function_definitions, + params=program.params, + declarations=declarations, + body=new_body, ) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 6852b47a7a..cbff44cecf 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -22,6 +22,7 @@ ir_makers as im, ) from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.utils import flatten_nested_tuple, tree_map @@ -418,10 +419,14 @@ def infer_program( not program.function_definitions ), "Domain propagation does not support function definitions." - return itir.Program( - id=program.id, - function_definitions=program.function_definitions, - params=program.params, - declarations=program.declarations, - body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], + return ConstantFolding.apply( # type: ignore[return-value] # returns same type as input + itir.Program( + id=program.id, + function_definitions=program.function_definitions, + params=program.params, + declarations=program.declarations, + body=[ + _infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body + ], + ) ) From e8d64f26afccb54ea368b45e322aa2712b8cc372 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 29 Nov 2024 15:19:36 +0100 Subject: [PATCH 4/5] Move ConstantFolding to domain_union --- .../next/iterator/ir_utils/domain_utils.py | 3 +++ .../next/iterator/transforms/infer_domain.py | 17 ++++++----------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index f5625b509c..109cfce083 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: @@ -168,6 +169,8 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.FunCall new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index cbff44cecf..6852b47a7a 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -22,7 +22,6 @@ ir_makers as im, ) from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.utils import flatten_nested_tuple, tree_map @@ -419,14 +418,10 @@ def infer_program( not program.function_definitions ), "Domain propagation does not support function definitions." - return ConstantFolding.apply( # type: ignore[return-value] # returns same type as input - itir.Program( - id=program.id, - function_definitions=program.function_definitions, - params=program.params, - declarations=program.declarations, - body=[ - _infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body - ], - ) + return itir.Program( + id=program.id, + function_definitions=program.function_definitions, + params=program.params, + declarations=program.declarations, + body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], ) From 5afd944a4a353e48c493059b3433f1b708687636 Mon Sep 17 00:00:00 2001 From: SF-N Date: Sat, 30 Nov 2024 09:17:18 +0100 Subject: [PATCH 5/5] Update src/gt4py/next/iterator/ir_utils/domain_utils.py Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 109cfce083..4a023f7535 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -170,7 +170,7 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: [domain.ranges[dim].stop for domain in domains], ) # constant fold expression to keep the tree small - start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.FunCall + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges)