-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat[next]: Add support for using Type Aliases (#1335)
* Add Type Alias replacement pass + tests * Fix: actual type not added in symbol list if already present * Address requested changes * Pre-commit fixes * Address requested changes * Prevent multiple float32 or float64 definitions in symtable * pre-commit run changes and 'returns' arg type modifications * Use 'from_type_hint' to avoid 'ScalarKind' construct --------- Co-authored-by: Nina Burgdorfer <[email protected]>
- Loading branch information
Showing
3 changed files
with
151 additions
and
0 deletions.
There are no files selected for viewing
105 changes: 105 additions & 0 deletions
105
src/gt4py/next/ffront/foast_passes/type_alias_replacement.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2023, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# This file is part of the GT4Py project and the GridTools framework. | ||
# GT4Py is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, cast | ||
|
||
import gt4py.next.ffront.field_operator_ast as foast | ||
from gt4py.eve import NodeTranslator, traits | ||
from gt4py.eve.concepts import SourceLocation, SymbolName, SymbolRef | ||
from gt4py.next.ffront import dialect_ast_enums | ||
from gt4py.next.ffront.fbuiltins import TYPE_BUILTIN_NAMES | ||
from gt4py.next.type_system import type_specifications as ts | ||
from gt4py.next.type_system.type_translation import from_type_hint | ||
|
||
|
||
@dataclass | ||
class TypeAliasReplacement(NodeTranslator, traits.VisitorWithSymbolTableTrait): | ||
""" | ||
Replace Type Aliases with their actual type. | ||
After this pass, the type aliases used for explicit construction of literal | ||
values and for casting field values are replaced by their actual types. | ||
""" | ||
|
||
closure_vars: dict[str, Any] | ||
|
||
@classmethod | ||
def apply( | ||
cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any] | ||
) -> tuple[foast.FunctionDefinition, dict[str, Any]]: | ||
foast_node = cls(closure_vars=closure_vars).visit(node) | ||
new_closure_vars = closure_vars.copy() | ||
for key, value in closure_vars.items(): | ||
if isinstance(value, type) and key not in TYPE_BUILTIN_NAMES: | ||
new_closure_vars[value.__name__] = closure_vars[key] | ||
return foast_node, new_closure_vars | ||
|
||
def is_type_alias(self, node_id: SymbolName | SymbolRef) -> bool: | ||
return ( | ||
node_id in self.closure_vars | ||
and isinstance(self.closure_vars[node_id], type) | ||
and node_id not in TYPE_BUILTIN_NAMES | ||
) | ||
|
||
def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: | ||
if self.is_type_alias(node.id): | ||
return foast.Name( | ||
id=self.closure_vars[node.id].__name__, location=node.location, type=node.type | ||
) | ||
return node | ||
|
||
def _update_closure_var_symbols( | ||
self, closure_vars: list[foast.Symbol], location: SourceLocation | ||
) -> list[foast.Symbol]: | ||
new_closure_vars: list[foast.Symbol] = [] | ||
existing_type_names: set[str] = set() | ||
|
||
for var in closure_vars: | ||
if self.is_type_alias(var.id): | ||
actual_type_name = self.closure_vars[var.id].__name__ | ||
# Avoid multiple definitions of a type in closure_vars | ||
if actual_type_name not in existing_type_names: | ||
new_closure_vars.append( | ||
foast.Symbol( | ||
id=actual_type_name, | ||
type=ts.FunctionType( | ||
pos_or_kw_args={}, | ||
kw_only_args={}, | ||
pos_only_args=[ts.DeferredType(constraint=ts.ScalarType)], | ||
returns=cast( | ||
ts.DataType, from_type_hint(self.closure_vars[var.id]) | ||
), | ||
), | ||
namespace=dialect_ast_enums.Namespace.CLOSURE, | ||
location=location, | ||
) | ||
) | ||
existing_type_names.add(actual_type_name) | ||
elif var.id not in existing_type_names: | ||
new_closure_vars.append(var) | ||
existing_type_names.add(var.id) | ||
|
||
return new_closure_vars | ||
|
||
def visit_FunctionDefinition( | ||
self, node: foast.FunctionDefinition, **kwargs | ||
) -> foast.FunctionDefinition: | ||
return foast.FunctionDefinition( | ||
id=node.id, | ||
params=node.params, | ||
body=self.visit(node.body, **kwargs), | ||
closure_vars=self._update_closure_var_symbols(node.closure_vars, node.location), | ||
location=node.location, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2023, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# This file is part of the GT4Py project and the GridTools framework. | ||
# GT4Py is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
import ast | ||
import typing | ||
from typing import TypeAlias | ||
|
||
import pytest | ||
|
||
import gt4py.next as gtx | ||
from gt4py.next import float32, float64 | ||
from gt4py.next.ffront.fbuiltins import astype | ||
from gt4py.next.ffront.func_to_foast import FieldOperatorParser | ||
|
||
|
||
TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. | ||
vpfloat: TypeAlias = float32 | ||
wpfloat: TypeAlias = float64 | ||
|
||
|
||
@pytest.mark.parametrize("test_input,expected", [(vpfloat, "float32"), (wpfloat, "float64")]) | ||
def test_type_alias_replacement(test_input, expected): | ||
def fieldop_with_typealias( | ||
a: gtx.Field[[TDim], test_input], b: gtx.Field[[TDim], float32] | ||
) -> gtx.Field[[TDim], test_input]: | ||
return test_input("3.1418") + astype(a, test_input) | ||
|
||
foast_tree = FieldOperatorParser.apply_to_function(fieldop_with_typealias) | ||
|
||
assert ( | ||
foast_tree.body.stmts[0].value.left.func.id == expected | ||
and foast_tree.body.stmts[0].value.right.args[1].id == expected | ||
) |