From 7b5226841f220fda48c1153df8d93d3f9d2abcff Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Mon, 20 Nov 2023 17:17:53 +0100 Subject: [PATCH] Modify past node with grid sizes --- tools/src/icon4pytools/icon4pygen/metadata.py | 66 +++++++++++++++++-- tools/tests/icon4pygen/test_metadata.py | 64 +++++++++++++++++- 2 files changed, 124 insertions(+), 6 deletions(-) diff --git a/tools/src/icon4pytools/icon4pygen/metadata.py b/tools/src/icon4pytools/icon4pygen/metadata.py index fd9869016e..c720b4b78f 100644 --- a/tools/src/icon4pytools/icon4pygen/metadata.py +++ b/tools/src/icon4pytools/icon4pygen/metadata.py @@ -12,6 +12,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later from __future__ import annotations +import copy import importlib import types from dataclasses import dataclass @@ -31,6 +32,7 @@ MultipleFieldOperatorException, ) from icon4pytools.icon4pygen.icochainsize import IcoChainSize +from gt4py.next.ffront import program_ast as past, dialect_ast_enums @dataclass(frozen=True) @@ -213,10 +215,10 @@ def scan_for_offsets(fvprog: Program) -> list[eve.concepts.SymbolRef]: all_dims = set(i for j in all_field_types for i in j.dims) all_offset_labels = ( fvprog.itir.pre_walk_values() - .if_isinstance(itir.OffsetLiteral) - .getattr("value") - .if_isinstance(str) - .to_list() + .if_isinstance(itir.OffsetLiteral) + .getattr("value") + .if_isinstance(str) + .to_list() ) all_dim_labels = [dim.value for dim in all_dims if dim.kind == DimensionKind.LOCAL] @@ -225,12 +227,68 @@ def scan_for_offsets(fvprog: Program) -> list[eve.concepts.SymbolRef]: return sorted_dims +def _create_symbol(symbol_name, symbol_type, starting_line, starting_column, filename) -> past.DataSymbol: + """ + Create a new DataSymbol instance for the given symbol name. + """ + location = eve.SourceLocation( + filename=filename, + line=starting_line, + end_line=starting_line, + column=starting_column, + end_column=starting_column + ) + return past.DataSymbol( + id=symbol_name, + type=symbol_type, + namespace=dialect_ast_enums.Namespace.LOCAL, + location=location + ) + + +def add_grid_element_sizes(past_node: past.Program) -> past.Program: + """ + Add grid element size parameters to the given program. + """ + symbol_type = ts.ScalarType(kind=ts.ScalarKind.INT32) + + # Create new type definitions + new_program_type = copy.deepcopy(past_node.type) + new_types = {"num_cells": symbol_type, "num_edges": symbol_type, "num_vertices": symbol_type} + new_program_type.definition.pos_or_kw_args.update(new_types) + + symbols = ["num_cells", "num_edges", "num_vertices"] + last_param = past_node.params[-1] + new_params = list(past_node.params) # Clone the existing params list + + # Starting location for new symbols + symbol_line = last_param.location.line + symbol_column = last_param.location.column + fname = last_param.location.filename + + for i, symbol in enumerate(symbols, start=1): + param = _create_symbol(symbol, symbol_type, symbol_line + i, symbol_column, fname) + new_params.append(param) + + new_program = past.Program( + id=past_node.id, + type=new_program_type, + params=new_params, + body=past_node.body, + closure_vars=past_node.closure_vars, + location=past_node.location + ) + + return new_program + + def get_stencil_info( fencil_def: Program | FieldOperator | types.FunctionType | FendefDispatcher, is_global: bool = False, ) -> StencilInfo: """Generate StencilInfo dataclass from a fencil definition.""" fvprog = get_fvprog(fencil_def) + fvprog = add_grid_element_sizes(fvprog) offsets = scan_for_offsets(fvprog) itir = fvprog.itir fields = _get_field_infos(fvprog) diff --git a/tools/tests/icon4pygen/test_metadata.py b/tools/tests/icon4pygen/test_metadata.py index c23665a963..c9e40c2083 100644 --- a/tools/tests/icon4pygen/test_metadata.py +++ b/tools/tests/icon4pygen/test_metadata.py @@ -10,14 +10,20 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import copy +import numpy as np import pytest + from gt4py.next.common import Field from gt4py.next.ffront.decorator import field_operator, program -from icon4py.model.common.dimension import CellDim, KDim -from icon4pytools.icon4pygen.metadata import _get_field_infos, provide_neighbor_table +from icon4py.model.common.dimension import CellDim, EdgeDim, E2V, E2VDim, KDim, VertexDim +from icon4py.model.common.grid.simple import SimpleGrid +from icon4py.model.common.test_utils.helpers import random_field + +from icon4pytools.icon4pygen.metadata import _get_field_infos, provide_neighbor_table, add_grid_element_sizes chain_false_skipvalues = [ "C2E", @@ -136,3 +142,57 @@ def test_get_field_infos_does_not_contain_domain_args(program): assert field_info["b"].inp assert field_info["result"].out assert not field_info["result"].inp + + +@field_operator +def testee_op(a: Field[[VertexDim], float]) -> Field[[EdgeDim], float]: + amul = a * 2.0 + return amul(E2V[0]) + amul(E2V[1]) + + +@program +def prog( + a: Field[[VertexDim], float], + out: Field[[EdgeDim], float], +) -> Field[[EdgeDim], float]: + testee_op(a, out=out) + + +def reference(grid, a): + amul = a * 2.0 + return amul[grid.connectivities[E2VDim][:, 0]] + amul[grid.connectivities[E2VDim][:, 1]] + + +def test_add_grid_sizes(): + original_prog = copy.deepcopy(prog.past_node) + result = add_grid_element_sizes(prog.past_node) + + new_symbols = {"num_cells", "num_edges", "num_vertices"} + result_symbols = {param.id for param in result.params} + for symbol in new_symbols: + assert symbol in result_symbols, f"Symbol {symbol} is not in program params" + + type_symbols = set(result.type.definition.pos_or_kw_args.keys()) + for symbol in new_symbols: + assert symbol in type_symbols, f"Symbol {symbol} is not in new_program_type definition" + + assert result.id == original_prog.id, "Program ID has changed" + assert result.body == original_prog.body, "Program body has changed" + assert result.closure_vars == original_prog.closure_vars, "Program closure_vars have changed" + + +def test_stencil(): + grid = SimpleGrid() + a = random_field(grid, VertexDim) + out = random_field(grid, EdgeDim) + offset_provider = {"E2V": grid.get_offset_provider("E2V")} + + ref = reference(grid, a.ndarray) + + # without addition of size args + prog(a, out, offset_provider=offset_provider) + + # todo: add size args and test for equality with original program without size args + # use from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries_and_sizes + + assert np.allclose(ref, out)