Skip to content

Commit

Permalink
feat[next]: Add support for using Type Aliases (#1335)
Browse files Browse the repository at this point in the history
* Add Type Alias replacement pass + tests

* Fix: actual type not added in symbol list if already present

* Address requested changes

* Pre-commit fixes

* Address requested changes

* Prevent multiple float32 or float64 definitions in symtable

* pre-commit run changes and 'returns' arg type modifications

* Use 'from_type_hint' to avoid 'ScalarKind' construct

---------

Co-authored-by: Nina Burgdorfer <[email protected]>
  • Loading branch information
ninaburg and Nina Burgdorfer authored Oct 5, 2023
1 parent 54bca83 commit 0d821b1
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 0 deletions.
105 changes: 105 additions & 0 deletions src/gt4py/next/ffront/foast_passes/type_alias_replacement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from dataclasses import dataclass
from typing import Any, cast

import gt4py.next.ffront.field_operator_ast as foast
from gt4py.eve import NodeTranslator, traits
from gt4py.eve.concepts import SourceLocation, SymbolName, SymbolRef
from gt4py.next.ffront import dialect_ast_enums
from gt4py.next.ffront.fbuiltins import TYPE_BUILTIN_NAMES
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.type_system.type_translation import from_type_hint


@dataclass
class TypeAliasReplacement(NodeTranslator, traits.VisitorWithSymbolTableTrait):
"""
Replace Type Aliases with their actual type.
After this pass, the type aliases used for explicit construction of literal
values and for casting field values are replaced by their actual types.
"""

closure_vars: dict[str, Any]

@classmethod
def apply(
cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any]
) -> tuple[foast.FunctionDefinition, dict[str, Any]]:
foast_node = cls(closure_vars=closure_vars).visit(node)
new_closure_vars = closure_vars.copy()
for key, value in closure_vars.items():
if isinstance(value, type) and key not in TYPE_BUILTIN_NAMES:
new_closure_vars[value.__name__] = closure_vars[key]
return foast_node, new_closure_vars

def is_type_alias(self, node_id: SymbolName | SymbolRef) -> bool:
return (
node_id in self.closure_vars
and isinstance(self.closure_vars[node_id], type)
and node_id not in TYPE_BUILTIN_NAMES
)

def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name:
if self.is_type_alias(node.id):
return foast.Name(
id=self.closure_vars[node.id].__name__, location=node.location, type=node.type
)
return node

def _update_closure_var_symbols(
self, closure_vars: list[foast.Symbol], location: SourceLocation
) -> list[foast.Symbol]:
new_closure_vars: list[foast.Symbol] = []
existing_type_names: set[str] = set()

for var in closure_vars:
if self.is_type_alias(var.id):
actual_type_name = self.closure_vars[var.id].__name__
# Avoid multiple definitions of a type in closure_vars
if actual_type_name not in existing_type_names:
new_closure_vars.append(
foast.Symbol(
id=actual_type_name,
type=ts.FunctionType(
pos_or_kw_args={},
kw_only_args={},
pos_only_args=[ts.DeferredType(constraint=ts.ScalarType)],
returns=cast(
ts.DataType, from_type_hint(self.closure_vars[var.id])
),
),
namespace=dialect_ast_enums.Namespace.CLOSURE,
location=location,
)
)
existing_type_names.add(actual_type_name)
elif var.id not in existing_type_names:
new_closure_vars.append(var)
existing_type_names.add(var.id)

return new_closure_vars

def visit_FunctionDefinition(
self, node: foast.FunctionDefinition, **kwargs
) -> foast.FunctionDefinition:
return foast.FunctionDefinition(
id=node.id,
params=node.params,
body=self.visit(node.body, **kwargs),
closure_vars=self._update_closure_var_symbols(node.closure_vars, node.location),
location=node.location,
)
2 changes: 2 additions & 0 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from gt4py.next.ffront.foast_passes.closure_var_type_deduction import ClosureVarTypeDeduction
from gt4py.next.ffront.foast_passes.dead_closure_var_elimination import DeadClosureVarElimination
from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass
from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement
from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation

Expand Down Expand Up @@ -91,6 +92,7 @@ def _postprocess_dialect_ast(
closure_vars: dict[str, Any],
annotations: dict[str, Any],
) -> foast.FunctionDefinition:
foast_node, closure_vars = TypeAliasReplacement.apply(foast_node, closure_vars)
foast_node = ClosureVarFolding.apply(foast_node, closure_vars)
foast_node = DeadClosureVarElimination.apply(foast_node)
foast_node = ClosureVarTypeDeduction.apply(foast_node, closure_vars)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import ast
import typing
from typing import TypeAlias

import pytest

import gt4py.next as gtx
from gt4py.next import float32, float64
from gt4py.next.ffront.fbuiltins import astype
from gt4py.next.ffront.func_to_foast import FieldOperatorParser


TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests.
vpfloat: TypeAlias = float32
wpfloat: TypeAlias = float64


@pytest.mark.parametrize("test_input,expected", [(vpfloat, "float32"), (wpfloat, "float64")])
def test_type_alias_replacement(test_input, expected):
def fieldop_with_typealias(
a: gtx.Field[[TDim], test_input], b: gtx.Field[[TDim], float32]
) -> gtx.Field[[TDim], test_input]:
return test_input("3.1418") + astype(a, test_input)

foast_tree = FieldOperatorParser.apply_to_function(fieldop_with_typealias)

assert (
foast_tree.body.stmts[0].value.left.func.id == expected
and foast_tree.body.stmts[0].value.right.args[1].id == expected
)

0 comments on commit 0d821b1

Please sign in to comment.