Skip to content

Commit

Permalink
Proper handling of OpenMP (`omp`) pragmas in the LowerBlockLoop transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 authored and reuterbal committed Sep 6, 2024
1 parent 8833741 commit 8abab40
Showing 1 changed file with 71 additions and 63 deletions.
134 changes: 71 additions & 63 deletions loki/transformations/block_index_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
# nor does it submit to any jurisdiction.

from loki.batch import Transformation, ProcedureItem
from loki.ir import nodes as ir, FindNodes, Transformer
from loki.ir import (
nodes as ir, FindNodes, Transformer, pragmas_attached,
pragma_regions_attached
)
from loki.module import Module
from loki.tools import as_tuple # , CaseInsensitiveDict
from loki.types import SymbolAttributes, BasicType
Expand Down Expand Up @@ -524,8 +527,8 @@ def process(self, routine, targets, role):
if str(call.name).lower() not in targets:
continue
if call.routine is BasicType.DEFERRED:
warning(f'[LowerBlockIndexTransformation] Not processing routine ' \
'{call.name}. Call statement not enriched')
warning('[LowerBlockIndexTransformation] Not processing routine ' \
f'{call.name}. Call statement not enriched')
continue
call_arg_map = dict((v,k) for k,v in call.arg_map.items())
call_block_dim_size = call_arg_map.get(block_dim_size, block_dim_size)
Expand Down Expand Up @@ -665,12 +668,6 @@ def local_var(self, call, variables):
if var not in call_routine_variables:
call.routine.variables += (var.clone(scope=call.routine),)

@staticmethod
def remove_openmp_pragmas(routine):
    """
    Strip all OpenMP pragmas from the body of ``routine``.

    Every ``ir.Pragma`` node found in ``routine.body`` whose keyword is
    ``omp`` (compared case-insensitively) is mapped to ``None`` and the
    body is rebuilt with a ``Transformer``; the result is written back
    to ``routine.body``.

    Parameters
    ----------
    routine :
        Object providing the ``body`` IR that is searched and replaced.
    """
    # Mapping a node to ``None`` makes the Transformer drop it
    removal_map = {
        node: None
        for node in FindNodes(ir.Pragma).visit(routine.body)
        if node.keyword.lower() == 'omp'
    }
    routine.body = Transformer(removal_map).visit(routine.body)

@staticmethod
def generate_pragma(loop):
return ir.Pragma(keyword="loki", content=f"removed_loop var({loop.variable}) \
Expand All @@ -690,57 +687,68 @@ def update_call_signature(self, call, loop, loop_defined_symbols, additional_kwa

def process_driver(self, routine, targets):
    """
    Lower the block loops of a driver ``routine`` into the called routines.

    A block loop is a loop whose variable equals ``self.block_dim.index``
    or one of ``self.block_dim._index_aliases``. For every block loop that
    contains at least one enriched call to a routine listed in ``targets``:

    * pragma regions that directly contain a block loop are replaced by
      their body,
    * for the first call to each routine, a copy of the loop (with all
      other calls removed and variables renamed according to the
      caller-to-callee argument map) is wrapped around the callee body,
      and scalars that are only known on the driver side are appended to
      the callee arguments,
    * the call on the driver side is updated via ``self.local_var`` and
      ``self.update_call_signature``; later calls to an already processed
      routine only get their signature updated,
    * the loop in the driver is replaced by its body.

    Parameters
    ----------
    routine :
        The driver routine; its ``body`` is rewritten.
    targets :
        Collection of lower-case routine names whose calls are processed.
    """
    # Loop pragmas and pragma regions are attached for the whole rewrite so
    # that pragmas bound to a block loop disappear together with that loop.
    # The final Transformer call must stay inside this context because the
    # collected loop nodes are the pragma-attached ones.
    with pragma_regions_attached(routine), pragmas_attached(routine, ir.Loop):
        # Find block loops
        block_loops = [
            loop for loop in FindNodes(ir.Loop).visit(routine.body)
            if loop.variable == self.block_dim.index
            or loop.variable in self.block_dim._index_aliases
        ]

        # Remove parallel regions around block loops
        pragma_region_map = {
            region: region.body
            for region in FindNodes(ir.PragmaRegion).visit(routine.body)
            if any(loop in region.body for loop in block_loops)
        }
        # In-place so that the loop nodes collected above stay valid
        routine.body = Transformer(pragma_region_map, inplace=True).visit(routine.body)

        driver_loop_map = {}
        processed_routines = set()
        additional_kwargs = {}
        for loop in block_loops:
            # Only enriched calls to target routines can be processed
            target_calls = [
                call for call in FindNodes(ir.CallStatement).visit(loop.body)
                if str(call.name).lower() in targets
                and call.routine is not BasicType.DEFERRED
            ]
            if not target_calls:
                continue

            defined_symbols_loop = [
                assign.lhs for assign in FindNodes(ir.Assignment).visit(loop.body)
            ]
            for call in target_calls:
                callee_name = call.routine.name
                if callee_name in processed_routines:
                    # The callee already contains the loop; only adapt this call
                    self.update_call_signature(call, loop, defined_symbols_loop,
                                               additional_kwargs[callee_name])
                    continue

                # 1. Create a copy of the loop with all other call statements removed
                other_calls = {
                    c: None for c in FindNodes(ir.CallStatement).visit(loop) if c is not call
                }
                loop_to_lower = Transformer(other_calls).visit(loop)

                # 2. Replace all variables according to the caller-callee argument map
                call_arg_map = {v: k for k, v in call.arg_map.items()}
                loop_to_lower = SubstituteExpressions(call_arg_map).visit(loop_to_lower)

                # 3. Identify local variables that need to be provided as additional
                #    arguments to the call
                call_routine_variables = {
                    v.name.lower() for v in FindVariables().visit(call.routine.body)
                }
                call_routine_variables |= {v.name.lower() for v in call.routine.variables}
                loop_variables = [
                    v for v in FindVariables().visit(loop_to_lower.body)
                    if v.name.lower() != loop.variable
                    and v.name.lower() not in call_routine_variables
                    and v not in call_arg_map
                    and isinstance(v, sym.Scalar)
                    and v not in defined_symbols_loop
                ]
                additional_kwargs[callee_name] = {var.name: var for var in loop_variables}

                # 4. Inject the loop body into the called routine
                call.routine.arguments += tuple(additional_kwargs[callee_name].values())
                callee_bodies = {
                    c: c.routine.body
                    for c in FindNodes(ir.CallStatement).visit(loop_to_lower)
                }
                routine_body = Transformer(callee_bodies).visit(loop_to_lower)
                routine_body = AttachScopes().visit(routine_body, scope=call.routine)
                call.routine.body = ir.Section(body=as_tuple(routine_body))

                # 5. Update the call on the caller side
                processed_routines.add(callee_name)
                self.local_var(call, defined_symbols_loop + [loop.variable])
                self.update_call_signature(call, loop, defined_symbols_loop,
                                           additional_kwargs[callee_name])

            # Replace the block loop in the driver by its body
            driver_loop_map[loop] = loop.body

        routine.body = Transformer(driver_loop_map).visit(routine.body)

0 comments on commit 8abab40

Please sign in to comment.