Skip to content

Commit

Permalink
Small fix for the lower-loop transformation; improve and extend the corresponding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 committed Jun 18, 2024
1 parent c09c352 commit a3b8146
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 52 deletions.
23 changes: 17 additions & 6 deletions loki/transformations/block_index_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,12 @@ def arg_to_local_var(routine, var):
type=routine.variable_map[var.name].type.clone(intent=None)),)

def local_var(self, call, var):
if var.name in call.arg_map:
self.arg_to_local_var(call.routine, call.arg_map[var.name])
elif var.name in call.routine.arguments:
# if var.name in call.arg_map:
# print(f"arg to local var [1] {var} | {call.arg_map[var.name]}")
# self.arg_to_local_var(call.routine, call.arg_map[var.name])
# elif var.name in call.routine.arguments:
if var.name in call.routine.arguments:
print(f"arg to local var [2] {var}")
self.arg_to_local_var(call.routine, var)
else:
call.routine.variables += (var.clone(scope=call.routine),)
Expand Down Expand Up @@ -576,9 +579,11 @@ def process_driver(self, routine, targets):
processed_routines = ()
calls = ()
for loop in loops:
lower_loop = False
for call in FindNodes(ir.CallStatement).visit(loop.body): #visit(routine.body):
if str(call.name).lower() not in targets:
continue
lower_loop = True
calls += (call,)
# take a copy of the loop that will be lowered
loop_to_lower = loop.clone()
Expand All @@ -587,8 +592,12 @@ def process_driver(self, routine, targets):
loop_to_lower = SubstituteExpressions(call_arg_map).visit(loop_to_lower)
# remove calls that are not within targets # TODO: rather a hack to remove
# "CALL TIMER%THREAD_LOG(TID, IGPC=ICEND)"
calls_within_loop = [call for call in FindNodes(ir.CallStatement).visit(loop_to_lower.body)
if str(call.name).lower() not in targets]
calls_within_loop = [_call for _call in FindNodes(ir.CallStatement).visit(loop_to_lower.body)
if str(_call.name).lower() not in targets]
loop_to_lower = Transformer({call: None for call in calls_within_loop}).visit(loop_to_lower)
# remove calls that are within targets except for relevant one
calls_within_loop = [_call for _call in FindNodes(ir.CallStatement).visit(loop_to_lower.body)
if str(_call.name).lower() in targets and str(_call.name).lower() != str(call.name).lower()]
loop_to_lower = Transformer({call: None for call in calls_within_loop}).visit(loop_to_lower)
# symbols that are defined or rather assigned within the loop
defined_symbols_loop = [assign.lhs for assign in FindNodes(ir.Assignment).visit(loop_to_lower.body)]
Expand Down Expand Up @@ -619,13 +628,15 @@ def process_driver(self, routine, targets):
call._update(pragma=(call.pragma if call.pragma else ()) + call_pragmas)
processed_routines += (call.routine.name,)
to_local_var[call.routine.name] = defined_symbols_loop + [loop.variable]
driver_loop_map[loop] = loop.body
if lower_loop:
driver_loop_map[loop] = loop.body
routine.body = Transformer(driver_loop_map).visit(routine.body)
for call in calls: # FindNodes(ir.CallStatement).visit(routine.body):
if str(call.name).lower() not in targets:
continue
# self.local_var(routine, call, loop.variable)
for var in to_local_var[call.routine.name]:
print(f"self.local_var for {var}")
self.local_var(call, var)
# TODO: remove
self.remove_openmp_pragmas(routine)
Loading

0 comments on commit a3b8146

Please sign in to comment.