Skip to content

Commit

Permalink
Implement Fortran intrinsic PRODUCT
Browse files Browse the repository at this point in the history
  • Loading branch information
mcopik committed Oct 18, 2023
1 parent 827bd1f commit 8fd8ba2
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 48 deletions.
140 changes: 92 additions & 48 deletions dace/frontend/fortran/intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class LoopBasedReplacement:
def replaced_name(func_name: str) -> str:
replacements = {
"SUM": "__dace_sum",
"PRODUCT": "__dace_product",
"ANY": "__dace_any",
"ALL": "__dace_all",
"COUNT": "__dace_count"
Expand All @@ -70,6 +71,7 @@ def replaced_name(func_name: str) -> str:
def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line) -> ast_internal_classes.FNode:
func_types = {
"__dace_sum": "DOUBLE",
"__dace_product": "DOUBLE",
"__dace_any": "INTEGER",
"__dace_all": "INTEGER",
"__dace_count": "INTEGER"
Expand Down Expand Up @@ -206,6 +208,72 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
self.count = self.count + range_index
return ast_internal_classes.Execution_Part_Node(execution=newbody)

class SumProduct(LoopBasedReplacementTransformation):

def __init__(self, ast):
super().__init__(ast)

def _initialize(self):
self.rvals = []
self.argument_variable = None

def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node):

for arg in node.args:

# supports syntax SUM(arr)
if isinstance(arg, ast_internal_classes.Name_Node):
array_node = ast_internal_classes.Array_Subscript_Node(parent=arg.parent)
array_node.name = arg

# If we access SUM(arr) where arr has many dimensions,
# We need to create a ParDecl_Node for each dimension
dims = len(self.scope_vars.get_var(node.parent, arg.name).sizes)
array_node.indices = [ast_internal_classes.ParDecl_Node(type='ALL')] * dims

self.rvals.append(array_node)

# supports syntax SUM(arr(:))
elif isinstance(arg, ast_internal_classes.Array_Subscript_Node):
self.rvals.append(arg)

else:
raise NotImplementedError("We do not support non-array arguments for SUM/PRODUCT")


def _summarize_args(self, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]):

if len(self.rvals) != 1:
raise NotImplementedError("Only one array can be summed")

self.argument_variable = self.rvals[0]

par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True)

def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node:

return ast_internal_classes.BinOp_Node(
lval=node.lval,
op="=",
rval=ast_internal_classes.Int_Literal_Node(value=self._result_init_value()),
line_number=node.line_number
)

def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node:

return ast_internal_classes.BinOp_Node(
lval=node.lval,
op="=",
rval=ast_internal_classes.BinOp_Node(
lval=node.lval,
op=self._result_update_op(),
rval=self.argument_variable,
line_number=node.line_number
),
line_number=node.line_number
)


class Sum(LoopBasedReplacement):

"""
Expand All @@ -217,70 +285,44 @@ class Sum(LoopBasedReplacement):
Then, we generate a binary node accumulating the result.
"""

class Transformation(LoopBasedReplacementTransformation):
class Transformation(SumProduct):

def __init__(self, ast):
super().__init__(ast)

def func_name(self) -> str:
return "__dace_sum"

def _initialize(self):
self.rvals = []
self.argument_variable = None

def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node):

for arg in node.args:

# supports syntax SUM(arr)
if isinstance(arg, ast_internal_classes.Name_Node):
array_node = ast_internal_classes.Array_Subscript_Node(parent=arg.parent)
array_node.name = arg

# If we access SUM(arr) where arr has many dimensions,
# We need to create a ParDecl_Node for each dimension
dims = len(self.scope_vars.get_var(node.parent, arg.name).sizes)
array_node.indices = [ast_internal_classes.ParDecl_Node(type='ALL')] * dims

self.rvals.append(array_node)

# supports syntax SUM(arr(:))
if isinstance(arg, ast_internal_classes.Array_Subscript_Node):
self.rvals.append(arg)
def _result_init_value(self):
return "0"

def _result_update_op(self):
return "+"

def _summarize_args(self, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]):
class Product(LoopBasedReplacement):

if len(self.rvals) != 1:
raise NotImplementedError("Only one array can be summed")
"""
In this class, we implement the transformation for Fortran intrinsic PRODUCT(:)
We support two ways of invoking the function - by providing array name and array subscript.
We do NOT support the *DIM* and *MASK* arguments.
self.argument_variable = self.rvals[0]
During the loop construction, we add a single variable storing the partial result.
Then, we generate a binary node accumulating the result.
"""

par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True)
class Transformation(SumProduct):

def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node:
def __init__(self, ast):
super().__init__(ast)

return ast_internal_classes.BinOp_Node(
lval=node.lval,
op="=",
rval=ast_internal_classes.Int_Literal_Node(value="0"),
line_number=node.line_number
)
def func_name(self) -> str:
return "__dace_product"

def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node:
def _result_init_value(self):
return "1"

return ast_internal_classes.BinOp_Node(
lval=node.lval,
op="=",
rval=ast_internal_classes.BinOp_Node(
lval=node.lval,
op="+",
rval=self.argument_variable,
line_number=node.line_number
),
line_number=node.line_number
)
def _result_update_op(self):
return "*"

class AnyAllCountTransformation(LoopBasedReplacementTransformation):

Expand Down Expand Up @@ -590,6 +632,7 @@ class FortranIntrinsics:
"SELECTED_INT_KIND": SelectedKind,
"SELECTED_REAL_KIND": SelectedKind,
"SUM": Sum,
"PRODUCT": Product,
"ANY": Any,
"COUNT": Count,
"ALL": All
Expand All @@ -599,6 +642,7 @@ class FortranIntrinsics:
"__dace_selected_int_kind": SelectedKind,
"__dace_selected_real_kind": SelectedKind,
"__dace_sum": Sum,
"__dace_product": Product,
"__dace_any": Any,
"__dace_all": All,
"__dace_count": Count
Expand Down
118 changes: 118 additions & 0 deletions tests/fortran/intrinsic_product.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.

import numpy as np
import pytest

from dace.frontend.fortran import ast_transforms, fortran_parser

def test_fortran_frontend_product_array():
"""
Tests that the generated array map correctly handles offsets.
"""
test_string = """
PROGRAM index_offset_test
implicit none
double precision, dimension(7) :: d
double precision, dimension(3) :: res
CALL index_test_function(d, res)
end
SUBROUTINE index_test_function(d, res)
double precision, dimension(7) :: d
double precision, dimension(3) :: res
res(1) = PRODUCT(d)
res(2) = PRODUCT(d(:))
res(3) = PRODUCT(d(2:5))
END SUBROUTINE index_test_function
"""

# Now test to verify it executes correctly with no offset normalization

sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False)
sdfg.simplify(verbose=True)
sdfg.compile()

size = 7
d = np.full([size], 0, order="F", dtype=np.float64)
for i in range(size):
d[i] = i + 1
res = np.full([3], 42, order="F", dtype=np.float64)
sdfg(d=d, res=res)
print(d)
print(res)
assert res[0] == np.prod(d)
assert res[1] == np.prod(d)
assert res[2] == np.prod(d[1:5])

def test_fortran_frontend_product_array_dim():
test_string = """
PROGRAM intrinsic_count_test
implicit none
logical, dimension(5) :: d
logical, dimension(2) :: res
CALL intrinsic_count_test_function(d, res)
end
SUBROUTINE intrinsic_count_test_function(d, res)
logical, dimension(5) :: d
logical, dimension(2) :: res
res(1) = PRODUCT(d, 1)
END SUBROUTINE intrinsic_count_test_function
"""

with pytest.raises(NotImplementedError):
fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False)

def test_fortran_frontend_product_2d():
"""
Tests that the generated array map correctly handles offsets.
"""
test_string = """
PROGRAM index_offset_test
implicit none
double precision, dimension(5,3) :: d
double precision, dimension(4) :: res
CALL index_test_function(d,res)
end
SUBROUTINE index_test_function(d, res)
double precision, dimension(5,3) :: d
double precision, dimension(4) :: res
res(1) = PRODUCT(d)
res(2) = PRODUCT(d(:,:))
res(3) = PRODUCT(d(2:4, 2))
res(4) = PRODUCT(d(2:4, 2:3))
END SUBROUTINE index_test_function
"""

# Now test to verify it executes correctly with no offset normalization

sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True)
sdfg.simplify(verbose=True)
sdfg.compile()

sizes = [5, 3]
d = np.full(sizes, 42, order="F", dtype=np.float64)
cnt = 1
for i in range(sizes[0]):
for j in range(sizes[1]):
d[i, j] = cnt
cnt += 1
res = np.full([4], 42, order="F", dtype=np.float64)
sdfg(d=d, res=res)
assert res[0] == np.prod(d)
assert res[1] == np.prod(d)
assert res[2] == np.prod(d[1:4, 1])
assert res[3] == np.prod(d[1:4, 1:3])

if __name__ == "__main__":

test_fortran_frontend_product_array()
test_fortran_frontend_product_array_dim()
test_fortran_frontend_product_2d()

0 comments on commit 8fd8ba2

Please sign in to comment.