Skip to content

Commit

Permalink
Modify past node with grid sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Nov 20, 2023
1 parent 9d0614b commit 7b52268
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 6 deletions.
66 changes: 62 additions & 4 deletions tools/src/icon4pytools/icon4pygen/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
from __future__ import annotations

import copy
import importlib
import types
from dataclasses import dataclass
Expand All @@ -31,6 +32,7 @@
MultipleFieldOperatorException,
)
from icon4pytools.icon4pygen.icochainsize import IcoChainSize
from gt4py.next.ffront import program_ast as past, dialect_ast_enums


@dataclass(frozen=True)
Expand Down Expand Up @@ -213,10 +215,10 @@ def scan_for_offsets(fvprog: Program) -> list[eve.concepts.SymbolRef]:
all_dims = set(i for j in all_field_types for i in j.dims)
all_offset_labels = (
fvprog.itir.pre_walk_values()
.if_isinstance(itir.OffsetLiteral)
.getattr("value")
.if_isinstance(str)
.to_list()
.if_isinstance(itir.OffsetLiteral)
.getattr("value")
.if_isinstance(str)
.to_list()
)
all_dim_labels = [dim.value for dim in all_dims if dim.kind == DimensionKind.LOCAL]

Expand All @@ -225,12 +227,68 @@ def scan_for_offsets(fvprog: Program) -> list[eve.concepts.SymbolRef]:
return sorted_dims


def _create_symbol(symbol_name, symbol_type, starting_line, starting_column, filename) -> past.DataSymbol:
"""
Create a new DataSymbol instance for the given symbol name.
"""
location = eve.SourceLocation(
filename=filename,
line=starting_line,
end_line=starting_line,
column=starting_column,
end_column=starting_column
)
return past.DataSymbol(
id=symbol_name,
type=symbol_type,
namespace=dialect_ast_enums.Namespace.LOCAL,
location=location
)


def add_grid_element_sizes(past_node: past.Program) -> past.Program:
"""
Add grid element size parameters to the given program.
"""
symbol_type = ts.ScalarType(kind=ts.ScalarKind.INT32)

# Create new type definitions
new_program_type = copy.deepcopy(past_node.type)
new_types = {"num_cells": symbol_type, "num_edges": symbol_type, "num_vertices": symbol_type}
new_program_type.definition.pos_or_kw_args.update(new_types)

symbols = ["num_cells", "num_edges", "num_vertices"]
last_param = past_node.params[-1]
new_params = list(past_node.params) # Clone the existing params list

# Starting location for new symbols
symbol_line = last_param.location.line
symbol_column = last_param.location.column
fname = last_param.location.filename

for i, symbol in enumerate(symbols, start=1):
param = _create_symbol(symbol, symbol_type, symbol_line + i, symbol_column, fname)
new_params.append(param)

new_program = past.Program(
id=past_node.id,
type=new_program_type,
params=new_params,
body=past_node.body,
closure_vars=past_node.closure_vars,
location=past_node.location
)

return new_program


def get_stencil_info(
fencil_def: Program | FieldOperator | types.FunctionType | FendefDispatcher,
is_global: bool = False,
) -> StencilInfo:
"""Generate StencilInfo dataclass from a fencil definition."""
fvprog = get_fvprog(fencil_def)
fvprog = add_grid_element_sizes(fvprog)
offsets = scan_for_offsets(fvprog)
itir = fvprog.itir
fields = _get_field_infos(fvprog)
Expand Down
64 changes: 62 additions & 2 deletions tools/tests/icon4pygen/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
import copy

import numpy as np
import pytest

from gt4py.next.common import Field
from gt4py.next.ffront.decorator import field_operator, program
from icon4py.model.common.dimension import CellDim, KDim

from icon4pytools.icon4pygen.metadata import _get_field_infos, provide_neighbor_table
from icon4py.model.common.dimension import CellDim, EdgeDim, E2V, E2VDim, KDim, VertexDim
from icon4py.model.common.grid.simple import SimpleGrid
from icon4py.model.common.test_utils.helpers import random_field


from icon4pytools.icon4pygen.metadata import _get_field_infos, provide_neighbor_table, add_grid_element_sizes

chain_false_skipvalues = [
"C2E",
Expand Down Expand Up @@ -136,3 +142,57 @@ def test_get_field_infos_does_not_contain_domain_args(program):
assert field_info["b"].inp
assert field_info["result"].out
assert not field_info["result"].inp


@field_operator
def testee_op(a: Field[[VertexDim], float]) -> Field[[EdgeDim], float]:
amul = a * 2.0
return amul(E2V[0]) + amul(E2V[1])


@program
def prog(
a: Field[[VertexDim], float],
out: Field[[EdgeDim], float],
) -> Field[[EdgeDim], float]:
testee_op(a, out=out)


def reference(grid, a):
amul = a * 2.0
return amul[grid.connectivities[E2VDim][:, 0]] + amul[grid.connectivities[E2VDim][:, 1]]


def test_add_grid_sizes():
original_prog = copy.deepcopy(prog.past_node)
result = add_grid_element_sizes(prog.past_node)

new_symbols = {"num_cells", "num_edges", "num_vertices"}
result_symbols = {param.id for param in result.params}
for symbol in new_symbols:
assert symbol in result_symbols, f"Symbol {symbol} is not in program params"

type_symbols = set(result.type.definition.pos_or_kw_args.keys())
for symbol in new_symbols:
assert symbol in type_symbols, f"Symbol {symbol} is not in new_program_type definition"

assert result.id == original_prog.id, "Program ID has changed"
assert result.body == original_prog.body, "Program body has changed"
assert result.closure_vars == original_prog.closure_vars, "Program closure_vars have changed"


def test_stencil():
grid = SimpleGrid()
a = random_field(grid, VertexDim)
out = random_field(grid, EdgeDim)
offset_provider = {"E2V": grid.get_offset_provider("E2V")}

ref = reference(grid, a.ndarray)

# without addition of size args
prog(a, out, offset_provider=offset_provider)

# todo: add size args and test for equality with original program without size args
# use from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries_and_sizes

assert np.allclose(ref, out)

0 comments on commit 7b52268

Please sign in to comment.