Skip to content

Commit

Permalink
Implement Fortran COUNT intrinsic
Browse files Browse the repository at this point in the history
  • Loading branch information
mcopik committed Oct 13, 2023
1 parent 51f149d commit 827bd1f
Show file tree
Hide file tree
Showing 2 changed files with 455 additions and 34 deletions.
120 changes: 86 additions & 34 deletions dace/frontend/fortran/intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def replaced_name(func_name: str) -> str:
replacements = {
"SUM": "__dace_sum",
"ANY": "__dace_any",
"ALL": "__dace_all"
"ALL": "__dace_all",
"COUNT": "__dace_count"
}
return replacements[func_name]

Expand All @@ -70,7 +71,8 @@ def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classe
func_types = {
"__dace_sum": "DOUBLE",
"__dace_any": "INTEGER",
"__dace_all": "INTEGER"
"__dace_all": "INTEGER",
"__dace_count": "INTEGER"
}
# FIXME: Any requires sometimes returning an array of booleans
call_type = func_types[func_name.name]
Expand Down Expand Up @@ -280,7 +282,7 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_
line_number=node.line_number
)

class AnyAllTransformation(LoopBasedReplacementTransformation):
class AnyAllCountTransformation(LoopBasedReplacementTransformation):

def __init__(self, ast):
super().__init__(ast)
Expand Down Expand Up @@ -429,11 +431,7 @@ def _summarize_args(self, node: ast_internal_classes.FNode, new_func_body: List[

def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node:

init_value = None
if 'any' in self.func_name():
init_value = "0"
else:
init_value = "1"
init_value = self._result_init_value()

return ast_internal_classes.BinOp_Node(
lval=node.lval,
Expand All @@ -443,25 +441,14 @@ def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_c
)

def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node:

"""
For any, we check if the condition is true and then set the value to true
For all, we check if the condition is NOT true and then set the value to false
"""

assign_value = None
if 'any' in self.func_name():
assign_value = "1"
else:
assign_value = "0"

body_if = ast_internal_classes.Execution_Part_Node(execution=[
ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(node.lval),
op="=",
rval=ast_internal_classes.Int_Literal_Node(value=assign_value),
line_number=node.line_number
),
self._result_loop_update(node),
# TODO: we should make the `break` generation conditional based on the architecture
# For parallel maps, we should have no breaks
# For sequential loop, we want a break to be faster
Expand All @@ -470,17 +457,8 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_
#)
])

condition = None
if 'any' in self.func_name():
condition = self.cond
else:
condition = ast_internal_classes.UnOp_Node(
op="not",
lval=self.cond
)

return ast_internal_classes.If_Stmt_Node(
cond=condition,
cond=self._loop_condition(),
body=body_if,
body_else=ast_internal_classes.Execution_Part_Node(execution=[]),
line_number=node.line_number
Expand Down Expand Up @@ -509,11 +487,26 @@ class Any(LoopBasedReplacement):
For (2), we reuse the provided binary operation.
When the condition is true, we set the value to true and exit.
"""
class Transformation(AnyAllTransformation):
class Transformation(AnyAllCountTransformation):

def __init__(self, ast):
super().__init__(ast)

def _result_init_value(self):
return "0"

def _result_loop_update(self, node: ast_internal_classes.FNode):

return ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(node.lval),
op="=",
rval=ast_internal_classes.Int_Literal_Node(value="1"),
line_number=node.line_number
)

def _loop_condition(self):
return self.cond

def func_name(self) -> str:
return "__dace_any"

Expand All @@ -525,14 +518,71 @@ class All(LoopBasedReplacement):
The main difference is that we initialize the partial result to 1,
and set it to 0 if any of the evaluated conditions is false.
"""
class Transformation(AnyAllTransformation):
class Transformation(AnyAllCountTransformation):

def __init__(self, ast):
super().__init__(ast)

def _result_init_value(self):
return "1"

def _result_loop_update(self, node: ast_internal_classes.FNode):

return ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(node.lval),
op="=",
rval=ast_internal_classes.Int_Literal_Node(value="0"),
line_number=node.line_number
)

def _loop_condition(self):
return ast_internal_classes.UnOp_Node(
op="not",
lval=self.cond
)

def func_name(self) -> str:
return "__dace_all"

class Count(LoopBasedReplacement):

"""
In this class, we implement the transformation for Fortran intrinsic COUNT.
The implementation is very similar to ANY and ALL.
The main difference is that we initialize the partial result to 0
and increment it if any of the evaluated conditions is true.
We do not support the KIND argument.
"""
class Transformation(AnyAllCountTransformation):

def __init__(self, ast):
super().__init__(ast)

def _result_init_value(self):
return "0"

def _result_loop_update(self, node: ast_internal_classes.FNode):

update = ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(node.lval),
op="+",
rval=ast_internal_classes.Int_Literal_Node(value="1"),
line_number=node.line_number
)
return ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(node.lval),
op="=",
rval=update,
line_number=node.line_number
)

def _loop_condition(self):
return self.cond

def func_name(self) -> str:
return "__dace_count"


class FortranIntrinsics:

Expand All @@ -541,6 +591,7 @@ class FortranIntrinsics:
"SELECTED_REAL_KIND": SelectedKind,
"SUM": Sum,
"ANY": Any,
"COUNT": Count,
"ALL": All
}

Expand All @@ -549,7 +600,8 @@ class FortranIntrinsics:
"__dace_selected_real_kind": SelectedKind,
"__dace_sum": Sum,
"__dace_any": Any,
"__dace_all": All
"__dace_all": All,
"__dace_count": Count
}

def __init__(self):
Expand Down
Loading

0 comments on commit 827bd1f

Please sign in to comment.