Skip to content

Commit

Permalink
Cherry-picked a handful of intrinsic related commits out of `multi_sd…
Browse files Browse the repository at this point in the history
…fg` branch. (#1728)

These commits:
1. are chronologically early
2. fairly independent and touch a small number of files
3. thematically fit together.

This is a follow-up on the discussion about incrementally moving the
stable features from `multi_sdfg` branch into `main`. All of the commits
are cherry-picked to preserve the history and to have easy
fast-forwarding of the `multi_sdfg` branch.

Some commits had to be slightly modified to resolve the merge conflicts
during cherry-picking, but otherwise they are left as-is.

---------

Co-authored-by: Marcin Copik <[email protected]>
  • Loading branch information
pratyai and mcopik authored Nov 6, 2024
1 parent 7c70423 commit 72ee732
Show file tree
Hide file tree
Showing 7 changed files with 1,475 additions and 114 deletions.
27 changes: 19 additions & 8 deletions dace/frontend/fortran/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(self, funcs=None):

from dace.frontend.fortran.intrinsics import FortranIntrinsics
self.excepted_funcs = [
"malloc", "exp", "pow", "sqrt", "cbrt", "max", "abs", "min", "__dace_sign", "tanh",
"malloc", "pow", "cbrt", "__dace_sign", "tanh", "atan2",
"__dace_epsilon", *FortranIntrinsics.function_names()
]

Expand Down Expand Up @@ -220,7 +220,7 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):

from dace.frontend.fortran.intrinsics import FortranIntrinsics
if not stop and node.name.name not in [
"malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()
"malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()
]:
self.nodes.append(node)
return self.generic_visit(node)
Expand All @@ -241,7 +241,7 @@ def __init__(self, count=0):
def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):

from dace.frontend.fortran.intrinsics import FortranIntrinsics
if node.name.name in ["malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()]:
if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()]:
return self.generic_visit(node)
if hasattr(node, "subroutine"):
if node.subroutine is True:
Expand All @@ -251,6 +251,11 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
else:
self.count = self.count + 1
tmp = self.count

for i, arg in enumerate(node.args):
# Ensure we allow to extract function calls from arguments
node.args[i] = self.visit(arg)

return ast_internal_classes.Name_Node(name="tmp_call_" + str(tmp - 1))

def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
Expand All @@ -263,9 +268,13 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
for i in res:
if i == child:
res.pop(res.index(i))
temp = self.count
if res is not None:
for i in range(0, len(res)):
# Variables are counted from 0...end, starting from main node, to all calls nested
# in main node arguments.
# However, we need to define nested ones first.
# We go in reverse order, counting from end-1 to 0.
temp = self.count + len(res) - 1
for i in reversed(range(0, len(res))):

newbody.append(
ast_internal_classes.Decl_Stmt_Node(vardecl=[
Expand All @@ -282,7 +291,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
type=res[i].type),
rval=res[i],
line_number=child.line_number))
temp = temp + 1
temp = temp - 1
if isinstance(child, ast_internal_classes.Call_Expr_Node):
new_args = []
if hasattr(child, "args"):
Expand Down Expand Up @@ -368,7 +377,8 @@ def __init__(self):
self.nodes: List[ast_internal_classes.Array_Subscript_Node] = []

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
if node.name.name in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh"]:
from dace.frontend.fortran.intrinsics import FortranIntrinsics
if node.name.name in ["pow", "atan2", "tanh", *FortranIntrinsics.retained_function_names()]:
return self.generic_visit(node)
else:
return
Expand Down Expand Up @@ -401,7 +411,8 @@ def __init__(self, ast: ast_internal_classes.FNode, normalize_offsets: bool = Fa
self.scope_vars.visit(ast)

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
if node.name.name in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh"]:
from dace.frontend.fortran.intrinsics import FortranIntrinsics
if node.name.name in ["pow", "atan2", "tanh", *FortranIntrinsics.retained_function_names()]:
return self.generic_visit(node)
else:
return node
Expand Down
12 changes: 8 additions & 4 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,8 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: Con
calls.visit(node)
if len(calls.nodes) == 1:
augmented_call = calls.nodes[0]
if augmented_call.name.name not in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh", "__dace_epsilon"]:
from dace.frontend.fortran.intrinsics import FortranIntrinsics
if augmented_call.name.name not in ["pow", "atan2", "tanh", "__dace_epsilon", *FortranIntrinsics.retained_function_names()]:
augmented_call.args.append(node.lval)
augmented_call.hasret = True
self.call2sdfg(augmented_call, sdfg, cfg)
Expand Down Expand Up @@ -1090,7 +1091,8 @@ def create_ast_from_string(
program = ast_transforms.ArrayToLoop(program).visit(program)

for transformation in own_ast.fortran_intrinsics().transformations():
program = transformation(program).visit(program)
transformation.initialize(program)
program = transformation.visit(program)

program = ast_transforms.ForDeclarer().visit(program)
program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program)
Expand Down Expand Up @@ -1126,7 +1128,8 @@ def create_sdfg_from_string(
program = ast_transforms.ArrayToLoop(program).visit(program)

for transformation in own_ast.fortran_intrinsics().transformations():
program = transformation(program).visit(program)
transformation.initialize(program)
program = transformation.visit(program)

program = ast_transforms.ForDeclarer().visit(program)
program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program)
Expand Down Expand Up @@ -1172,7 +1175,8 @@ def create_sdfg_from_fortran_file(source_string: str, use_experimental_cfg_block
program = ast_transforms.ArrayToLoop(program).visit(program)

for transformation in own_ast.fortran_intrinsics():
program = transformation(program).visit(program)
transformation.initialize(program)
program = transformation.visit(program)

program = ast_transforms.ForDeclarer().visit(program)
program = ast_transforms.IndexExtractor(program).visit(program)
Expand Down
Loading

0 comments on commit 72ee732

Please sign in to comment.