From 096c5c64d5f3932f575f6df7c70879f5458c7e7f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 2 Dec 2024 09:41:56 +0100 Subject: [PATCH] minor edit --- .../dace_fieldview/gtir_python_codegen.py | 42 ++++++++----------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 9148284398..95b7ce5213 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Callable, Sequence +from typing import Any, Callable import numpy as np @@ -120,37 +120,27 @@ class PythonCodegen(codegen.TemplatedGenerator): Literal = as_fmt("{value}") - def _visit_deref(self, node: gtir.FunCall, symbol_mapping: dict[str, gtir.Node]) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], gtir.SymRef): - return self.visit(node.args[0], symbol_mapping=symbol_mapping) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def _visit_lambda( - self, - node: gtir.Lambda, - node_args: Sequence[gtir.Node], - symbol_mapping: dict[str, gtir.Node], - ) -> str: - symbol_mapping |= {param.id: arg for param, arg in zip(node.params, node_args)} - return self.visit(node.expr, symbol_mapping=symbol_mapping) - - def visit_FunCall(self, node: gtir.FunCall, symbol_mapping: dict[str, gtir.Node]) -> str: + def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str: if isinstance(node.fun, gtir.Lambda): - return self._visit_lambda(node.fun, node.args, symbol_mapping=symbol_mapping) + # update the mapping from lambda parameters to corresponding argument expressions + args_map |= {p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True)} + return self.visit(node.fun.expr, args_map=args_map) elif cpm.is_call_to(node, "deref"): - return self._visit_deref(node, symbol_mapping=symbol_mapping) + assert len(node.args) == 1 + if not isinstance(node.args[0], gtir.SymRef): + # shift expressions are not expected in this visitor context + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + return self.visit(node.args[0], args_map=args_map) elif isinstance(node.fun, gtir.SymRef): - args = self.visit(node.args, symbol_mapping=symbol_mapping) + args = self.visit(node.args, args_map=args_map) builtin_name = str(node.fun.id) return format_builtin(builtin_name, *args) raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") - def visit_SymRef(self, node: gtir.SymRef, symbol_mapping: dict[str, gtir.Node]) -> str: + def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str: symbol = str(node.id) - if symbol_mapping and symbol in symbol_mapping: - mapped_node = symbol_mapping[symbol] - return self.visit(mapped_node, symbol_mapping=symbol_mapping) + if symbol in args_map: + return self.visit(args_map[symbol], args_map=args_map) return symbol @@ -158,7 +148,9 @@ def get_source(node: gtir.Node) -> str: """ Specialized visit method for symbolic expressions. + The visitor uses `args_map` to map lambda parameters to the corresponding argument expressions. + Returns: A string containing the Python code corresponding to a symbolic expression """ - return PythonCodegen.apply(node, symbol_mapping={}) + return PythonCodegen.apply(node, args_map={})