diff --git a/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py new file mode 100644 index 0000000000..c5857999ee --- /dev/null +++ b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py @@ -0,0 +1,105 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from typing import Any, cast + +import gt4py.next.ffront.field_operator_ast as foast +from gt4py.eve import NodeTranslator, traits +from gt4py.eve.concepts import SourceLocation, SymbolName, SymbolRef +from gt4py.next.ffront import dialect_ast_enums +from gt4py.next.ffront.fbuiltins import TYPE_BUILTIN_NAMES +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system.type_translation import from_type_hint + + +@dataclass +class TypeAliasReplacement(NodeTranslator, traits.VisitorWithSymbolTableTrait): + """ + Replace Type Aliases with their actual type. + + After this pass, the type aliases used for explicit construction of literal + values and for casting field values are replaced by their actual types. + """ + + closure_vars: dict[str, Any] + + @classmethod + def apply( + cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any] + ) -> tuple[foast.FunctionDefinition, dict[str, Any]]: + foast_node = cls(closure_vars=closure_vars).visit(node) + new_closure_vars = closure_vars.copy() + for key, value in closure_vars.items(): + if isinstance(value, type) and key not in TYPE_BUILTIN_NAMES: + new_closure_vars[value.__name__] = closure_vars[key] + return foast_node, new_closure_vars + + def is_type_alias(self, node_id: SymbolName | SymbolRef) -> bool: + return ( + node_id in self.closure_vars + and isinstance(self.closure_vars[node_id], type) + and node_id not in TYPE_BUILTIN_NAMES + ) + + def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: + if self.is_type_alias(node.id): + return foast.Name( + id=self.closure_vars[node.id].__name__, location=node.location, type=node.type + ) + return node + + def _update_closure_var_symbols( + self, closure_vars: list[foast.Symbol], location: SourceLocation + ) -> list[foast.Symbol]: + new_closure_vars: list[foast.Symbol] = [] + existing_type_names: set[str] = set() + + for var in closure_vars: + if self.is_type_alias(var.id): + actual_type_name = self.closure_vars[var.id].__name__ + # Avoid multiple definitions of a type in closure_vars + if actual_type_name not in existing_type_names: + new_closure_vars.append( + foast.Symbol( + id=actual_type_name, + type=ts.FunctionType( + pos_or_kw_args={}, + kw_only_args={}, + pos_only_args=[ts.DeferredType(constraint=ts.ScalarType)], + returns=cast( + ts.DataType, from_type_hint(self.closure_vars[var.id]) + ), + ), + namespace=dialect_ast_enums.Namespace.CLOSURE, + location=location, + ) + ) + existing_type_names.add(actual_type_name) + elif var.id not in existing_type_names: + new_closure_vars.append(var) + existing_type_names.add(var.id) + + return new_closure_vars + + def visit_FunctionDefinition( + self, node: foast.FunctionDefinition, **kwargs + ) -> foast.FunctionDefinition: + return foast.FunctionDefinition( + id=node.id, + params=node.params, + body=self.visit(node.body, **kwargs), + closure_vars=self._update_closure_var_symbols(node.closure_vars, node.location), + location=node.location, + ) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 082939c938..c7c4c3a23f 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -33,6 +33,7 @@ from gt4py.next.ffront.foast_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.foast_passes.dead_closure_var_elimination import DeadClosureVarElimination from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass +from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -91,6 +92,7 @@ def _postprocess_dialect_ast( closure_vars: dict[str, Any], annotations: dict[str, Any], ) -> foast.FunctionDefinition: + foast_node, closure_vars = TypeAliasReplacement.apply(foast_node, closure_vars) foast_node = ClosureVarFolding.apply(foast_node, closure_vars) foast_node = DeadClosureVarElimination.apply(foast_node) foast_node = ClosureVarTypeDeduction.apply(foast_node, closure_vars) diff --git a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py new file mode 100644 index 0000000000..e87f869352 --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py @@ -0,0 +1,44 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import ast +import typing +from typing import TypeAlias + +import pytest + +import gt4py.next as gtx +from gt4py.next import float32, float64 +from gt4py.next.ffront.fbuiltins import astype +from gt4py.next.ffront.func_to_foast import FieldOperatorParser + + +TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. +vpfloat: TypeAlias = float32 +wpfloat: TypeAlias = float64 + + +@pytest.mark.parametrize("test_input,expected", [(vpfloat, "float32"), (wpfloat, "float64")]) +def test_type_alias_replacement(test_input, expected): + def fieldop_with_typealias( + a: gtx.Field[[TDim], test_input], b: gtx.Field[[TDim], float32] + ) -> gtx.Field[[TDim], test_input]: + return test_input("3.1418") + astype(a, test_input) + + foast_tree = FieldOperatorParser.apply_to_function(fieldop_with_typealias) + + assert ( + foast_tree.body.stmts[0].value.left.func.id == expected + and foast_tree.body.stmts[0].value.right.args[1].id == expected + )