diff --git a/dace/frontend/fortran/intrinsics.py b/dace/frontend/fortran/intrinsics.py index e0301859d3..ad990cfcba 100644 --- a/dace/frontend/fortran/intrinsics.py +++ b/dace/frontend/fortran/intrinsics.py @@ -60,6 +60,7 @@ class LoopBasedReplacement: def replaced_name(func_name: str) -> str: replacements = { "SUM": "__dace_sum", + "PRODUCT": "__dace_product", "ANY": "__dace_any", "ALL": "__dace_all", "COUNT": "__dace_count" @@ -70,6 +71,7 @@ def replaced_name(func_name: str) -> str: def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line) -> ast_internal_classes.FNode: func_types = { "__dace_sum": "DOUBLE", + "__dace_product": "DOUBLE", "__dace_any": "INTEGER", "__dace_all": "INTEGER", "__dace_count": "INTEGER" @@ -206,6 +208,72 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No self.count = self.count + range_index return ast_internal_classes.Execution_Part_Node(execution=newbody) +class SumProduct(LoopBasedReplacementTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _initialize(self): + self.rvals = [] + self.argument_variable = None + + def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): + + for arg in node.args: + + # supports syntax SUM(arr) + if isinstance(arg, ast_internal_classes.Name_Node): + array_node = ast_internal_classes.Array_Subscript_Node(parent=arg.parent) + array_node.name = arg + + # If we access SUM(arr) where arr has many dimensions, + # We need to create a ParDecl_Node for each dimension + dims = len(self.scope_vars.get_var(node.parent, arg.name).sizes) + array_node.indices = [ast_internal_classes.ParDecl_Node(type='ALL')] * dims + + self.rvals.append(array_node) + + # supports syntax SUM(arr(:)) + elif isinstance(arg, ast_internal_classes.Array_Subscript_Node): + self.rvals.append(arg) + + else: + raise NotImplementedError("We do not support non-array arguments for SUM/PRODUCT") + + + def _summarize_args(self, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + + if len(self.rvals) != 1: + raise NotImplementedError("Only one array can be summed") + + self.argument_variable = self.rvals[0] + + par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + + def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + + return ast_internal_classes.BinOp_Node( + lval=node.lval, + op="=", + rval=ast_internal_classes.Int_Literal_Node(value=self._result_init_value()), + line_number=node.line_number + ) + + def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + + return ast_internal_classes.BinOp_Node( + lval=node.lval, + op="=", + rval=ast_internal_classes.BinOp_Node( + lval=node.lval, + op=self._result_update_op(), + rval=self.argument_variable, + line_number=node.line_number + ), + line_number=node.line_number + ) + + class Sum(LoopBasedReplacement): """ @@ -217,7 +285,7 @@ class Sum(LoopBasedReplacement): Then, we generate a binary node accumulating the result. """ - class Transformation(LoopBasedReplacementTransformation): + class Transformation(SumProduct): def __init__(self, ast): super().__init__(ast) @@ -225,62 +293,36 @@ def __init__(self, ast): def func_name(self) -> str: return "__dace_sum" - def _initialize(self): - self.rvals = [] - self.argument_variable = None - - def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): - - for arg in node.args: - - # supports syntax SUM(arr) - if isinstance(arg, ast_internal_classes.Name_Node): - array_node = ast_internal_classes.Array_Subscript_Node(parent=arg.parent) - array_node.name = arg - - # If we access SUM(arr) where arr has many dimensions, - # We need to create a ParDecl_Node for each dimension - dims = len(self.scope_vars.get_var(node.parent, arg.name).sizes) - array_node.indices = [ast_internal_classes.ParDecl_Node(type='ALL')] * dims - - self.rvals.append(array_node) - - # supports syntax SUM(arr(:)) - if isinstance(arg, ast_internal_classes.Array_Subscript_Node): - self.rvals.append(arg) + def _result_init_value(self): + return "0" + def _result_update_op(self): + return "+" - def _summarize_args(self, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): +class Product(LoopBasedReplacement): - if len(self.rvals) != 1: - raise NotImplementedError("Only one array can be summed") + """ + In this class, we implement the transformation for Fortran intrinsic PRODUCT(:) + We support two ways of invoking the function - by providing array name and array subscript. + We do NOT support the *DIM* and *MASK* arguments. - self.argument_variable = self.rvals[0] + During the loop construction, we add a single variable storing the partial result. + Then, we generate a binary node accumulating the result. + """ - par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + class Transformation(SumProduct): - def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + def __init__(self, ast): + super().__init__(ast) - return ast_internal_classes.BinOp_Node( - lval=node.lval, - op="=", - rval=ast_internal_classes.Int_Literal_Node(value="0"), - line_number=node.line_number - ) + def func_name(self) -> str: + return "__dace_product" - def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + def _result_init_value(self): + return "1" - return ast_internal_classes.BinOp_Node( - lval=node.lval, - op="=", - rval=ast_internal_classes.BinOp_Node( - lval=node.lval, - op="+", - rval=self.argument_variable, - line_number=node.line_number - ), - line_number=node.line_number - ) + def _result_update_op(self): + return "*" class AnyAllCountTransformation(LoopBasedReplacementTransformation): @@ -590,6 +632,7 @@ class FortranIntrinsics: "SELECTED_INT_KIND": SelectedKind, "SELECTED_REAL_KIND": SelectedKind, "SUM": Sum, + "PRODUCT": Product, "ANY": Any, "COUNT": Count, "ALL": All @@ -599,6 +642,7 @@ class FortranIntrinsics: "__dace_selected_int_kind": SelectedKind, "__dace_selected_real_kind": SelectedKind, "__dace_sum": Sum, + "__dace_product": Product, "__dace_any": Any, "__dace_all": All, "__dace_count": Count diff --git a/tests/fortran/intrinsic_product.py b/tests/fortran/intrinsic_product.py new file mode 100644 index 0000000000..06d14e0a34 --- /dev/null +++ b/tests/fortran/intrinsic_product.py @@ -0,0 +1,118 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import ast_transforms, fortran_parser + +def test_fortran_frontend_product_array(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(7) :: d + double precision, dimension(3) :: res + CALL index_test_function(d, res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(7) :: d + double precision, dimension(3) :: res + + res(1) = PRODUCT(d) + res(2) = PRODUCT(d(:)) + res(3) = PRODUCT(d(2:5)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + print(d) + print(res) + assert res[0] == np.prod(d) + assert res[1] == np.prod(d) + assert res[2] == np.prod(d[1:5]) + +def test_fortran_frontend_product_array_dim(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + logical, dimension(5) :: d + logical, dimension(2) :: res + CALL intrinsic_count_test_function(d, res) + end + + SUBROUTINE intrinsic_count_test_function(d, res) + logical, dimension(5) :: d + logical, dimension(2) :: res + + res(1) = PRODUCT(d, 1) + + END SUBROUTINE intrinsic_count_test_function + """ + + with pytest.raises(NotImplementedError): + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + +def test_fortran_frontend_product_2d(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + + res(1) = PRODUCT(d) + res(2) = PRODUCT(d(:,:)) + res(3) = PRODUCT(d(2:4, 2)) + res(4) = PRODUCT(d(2:4, 2:3)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 3] + d = np.full(sizes, 42, order="F", dtype=np.float64) + cnt = 1 + for i in range(sizes[0]): + for j in range(sizes[1]): + d[i, j] = cnt + cnt += 1 + res = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == np.prod(d) + assert res[1] == np.prod(d) + assert res[2] == np.prod(d[1:4, 1]) + assert res[3] == np.prod(d[1:4, 1:3]) + +if __name__ == "__main__": + + test_fortran_frontend_product_array() + test_fortran_frontend_product_array_dim() + test_fortran_frontend_product_2d()