Skip to content

Commit

Permalink
SCC: Mark driver loops in SCCRevector and SCCAnnotate translates
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Aug 8, 2024
1 parent fecac7d commit 7121819
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 60 deletions.
71 changes: 26 additions & 45 deletions loki/transformations/single_column/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from loki.ir import (
nodes as ir, FindNodes, Transformer, pragmas_attached,
pragma_regions_attached, is_loki_pragma
pragma_regions_attached, is_loki_pragma, get_pragma_parameters
)
from loki.logging import info
from loki.tools import as_tuple, flatten
Expand Down Expand Up @@ -211,29 +211,16 @@ def process_driver(self, routine, targets=None):
the transformation call tree.
"""

# For the thread block size, find the horizontal size variable that is available in
# the driver
num_threads = None
symbol_map = routine.symbol_map
for size_expr in self.horizontal.size_expressions:
if size_expr in symbol_map:
num_threads = size_expr
break
# Mark all parallel vector loops as `!$acc loop vector`
self.kernel_annotate_vector_loops_openacc(routine)

# Mark all non-parallel loops as `!$acc loop seq`
self.kernel_annotate_sequential_loops_openacc(routine)

with pragmas_attached(routine, ir.Loop, attach_pragma_post=True):
driver_loops = find_driver_loops(routine=routine, targets=targets)
for loop in driver_loops:
loops = FindNodes(ir.Loop).visit(loop.body)
kernel_loops = [l for l in loops if l.variable == self.horizontal.index]
if kernel_loops:
assert not loop == kernel_loops[0]
self.annotate_driver(
self.directive, loop, kernel_loops, self.block_dim, num_threads
)

if self.directive == 'openacc':
# Mark all non-parallel loops as `!$acc loop seq`
self.kernel_annotate_sequential_loops_openacc(routine)
self.annotate_driver(self.directive, loop, self.block_dim)

@classmethod
def device_alloc_column_locals(cls, routine, column_locals):
Expand All @@ -257,7 +244,7 @@ def device_alloc_column_locals(cls, routine, column_locals):
routine.body.append((ir.Comment(''), pragma_post, ir.Comment('')))

@classmethod
def annotate_driver(cls, directive, driver_loop, kernel_loops, block_dim, num_threads):
def annotate_driver(cls, directive, driver_loop, block_dim):
"""
Annotate driver block loop with ``'openacc'`` pragmas.
Expand All @@ -273,8 +260,6 @@ def annotate_driver(cls, directive, driver_loop, kernel_loops, block_dim, num_th
block_dim : :any:`Dimension`
Optional ``Dimension`` object to define the blocking dimension
to detect hoisted temporary arrays and excempt them from marking.
num_threads : str
The size expression that determines the number of threads per thread block
"""

# Mark driver loop as "gang parallel".
Expand All @@ -289,25 +274,21 @@ def annotate_driver(cls, directive, driver_loop, kernel_loops, block_dim, num_th
arrays = [v for v in arrays if not any(d in sizes for d in as_tuple(v.shape))]
private_arrays = ', '.join(set(v.name for v in arrays))
private_clause = '' if not private_arrays else f' private({private_arrays})'
vector_length_clause = '' if not num_threads else f' vector_length({num_threads})'

# Annotate vector loops with OpenACC pragmas
if kernel_loops:
for loop in as_tuple(kernel_loops):
loop._update(pragma=(ir.Pragma(keyword='acc', content='loop vector'),))

if driver_loop.pragma is None or (len(driver_loop.pragma) == 1 and
driver_loop.pragma[0].keyword.lower() == "loki" and
driver_loop.pragma[0].content.lower() == "driver-loop"):
p_content = f'parallel loop gang{private_clause}{vector_length_clause}'
driver_loop._update(pragma=(ir.Pragma(keyword='acc', content=p_content),))
driver_loop._update(pragma_post=(ir.Pragma(keyword='acc', content='end parallel loop'),))

# add acc parallel loop gang if the only existing pragma is acc data
elif len(driver_loop.pragma) == 1:
if (driver_loop.pragma[0].keyword == 'acc' and
driver_loop.pragma[0].content.lower().lstrip().startswith('data ')):
p_content = f'parallel loop gang{private_clause}{vector_length_clause}'
driver_loop._update(pragma=(driver_loop.pragma[0], ir.Pragma(keyword='acc', content=p_content)))
driver_loop._update(pragma_post=(ir.Pragma(keyword='acc', content='end parallel loop'),
driver_loop.pragma_post[0]))

for pragma in as_tuple(driver_loop.pragma):
if is_loki_pragma(pragma, starts_with='loop driver'):
# Replace `!$loki loop driver` pragma with OpenACC equivalent
params = get_pragma_parameters(driver_loop.pragma, starts_with='loop driver')
vlength = params.get('vector_length')
vlength_clause = f' vector_length({vlength})' if vlength else ''

content = f'parallel loop gang{private_clause}{vlength_clause}'
pragma_new = ir.Pragma(keyword='acc', content=content)
pragma_post = ir.Pragma(keyword='acc', content='end parallel loop')

# Replace existing loki pragma and add post-pragma
loop_pragmas = tuple(p for p in as_tuple(driver_loop.pragma) if p is not pragma)
driver_loop._update(
pragma=loop_pragmas + (pragma_new,),
pragma_post=(pragma_post,) + as_tuple(driver_loop.pragma_post)
)
32 changes: 20 additions & 12 deletions loki/transformations/single_column/tests/test_scc_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_scc_revector_transformation(frontend, horizontal):
scc_transform = (SCCDevectorTransformation(horizontal=horizontal),)
scc_transform += (SCCRevectorTransformation(horizontal=horizontal),)
for transform in scc_transform:
transform.apply(driver, role='driver')
transform.apply(driver, role='driver', targets=('compute_column',))
transform.apply(kernel, role='kernel')

# Ensure we have two nested loops in the kernel
Expand Down Expand Up @@ -126,11 +126,15 @@ def test_scc_revector_transformation(frontend, horizontal):
sections = FindNodes(ir.Section).visit(kernel.body)
assert all(not s.label for s in sections)

# Ensure driver remains unaffected
driver_loops = FindNodes(ir.Loop).visit(driver.body)
assert len(driver_loops) == 1
assert driver_loops[0].variable == 'b'
assert driver_loops[0].bounds == '1:nb'
# Ensure driver remains unaffected and is marked
with pragmas_attached(driver, node_type=ir.Loop):
driver_loops = FindNodes(ir.Loop).visit(driver.body)
assert len(driver_loops) == 1
assert driver_loops[0].variable == 'b'
assert driver_loops[0].bounds == '1:nb'
assert driver_loops[0].pragma and len(driver_loops[0].pragma) == 1
assert is_loki_pragma(driver_loops[0].pragma[0], starts_with='loop driver')
assert 'vector_length(nlon)' in driver_loops[0].pragma[0].content

kernel_calls = FindNodes(ir.CallStatement).visit(driver_loops[0])
assert len(kernel_calls) == 1
Expand Down Expand Up @@ -209,7 +213,7 @@ def test_scc_revector_transformation_aliased_bounds(frontend, horizontal_bounds_
scc_transform = (SCCDevectorTransformation(horizontal=horizontal_bounds_aliases),)
scc_transform += (SCCRevectorTransformation(horizontal=horizontal_bounds_aliases),)
for transform in scc_transform:
transform.apply(driver, role='driver')
transform.apply(driver, role='driver', targets=('compute_column',))
transform.apply(kernel, role='kernel')

# Ensure we have two nested loops in the kernel
Expand Down Expand Up @@ -239,11 +243,15 @@ def test_scc_revector_transformation_aliased_bounds(frontend, horizontal_bounds_
sections = FindNodes(ir.Section).visit(kernel.body)
assert all(not s.label for s in sections)

# Ensure driver remains unaffected
driver_loops = FindNodes(ir.Loop).visit(driver.body)
assert len(driver_loops) == 1
assert driver_loops[0].variable == 'b'
assert driver_loops[0].bounds == '1:nb'
# Ensure driver remains unaffected and is marked
with pragmas_attached(driver, node_type=ir.Loop):
driver_loops = FindNodes(ir.Loop).visit(driver.body)
assert len(driver_loops) == 1
assert driver_loops[0].variable == 'b'
assert driver_loops[0].bounds == '1:nb'
assert driver_loops[0].pragma and len(driver_loops[0].pragma) == 1
assert is_loki_pragma(driver_loops[0].pragma[0], starts_with='loop driver')
assert 'vector_length(nlon)' in driver_loops[0].pragma[0].content

kernel_calls = FindNodes(ir.CallStatement).visit(driver_loops[0])
assert len(kernel_calls) == 1
Expand Down
28 changes: 27 additions & 1 deletion loki/transformations/single_column/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,29 @@ def mark_seq_loops(self, section):
if loop.variable != self.horizontal.index:
loop._update(pragma=(ir.Pragma(keyword='loki', content='loop seq'),))

def mark_driver_loop(self, routine, loop):
"""
Add ``!$loki loop driver`` pragmas to outer block loops and
add ``vector-length(size)`` clause for later annotations.
This method assumes that pragmas have been attached via
:any:`pragmas_attached`.
"""
# Find a horizontal size variable to mark vector_length
symbol_map = routine.symbol_map
sizes = tuple(
symbol_map.get(size) for size in self.horizontal.size_expressions
if size in symbol_map
)
vector_length = f' vector_length({sizes[0]})' if sizes else ''

# Replace existing `!$loki loop driver markers, but leave all others
pragma = ir.Pragma(keyword='loki', content=f'loop driver{vector_length}')
loop_pragmas = tuple(
p for p in as_tuple(loop.pragma) if not is_loki_pragma(p, starts_with='driver-loop')
)
loop._update(pragma=loop_pragmas + (pragma,))

def transform_subroutine(self, routine, **kwargs):
"""
Wrap vector-parallel sections in vector :any:`Loop` objects.
Expand Down Expand Up @@ -372,7 +395,7 @@ def transform_subroutine(self, routine, **kwargs):
self.mark_seq_loops(routine.body)

if role == 'driver':
with pragmas_attached(routine, ir.Loop, attach_pragma_post=True):
with pragmas_attached(routine, ir.Loop):
driver_loops = find_driver_loops(routine=routine, targets=targets)

for loop in driver_loops:
Expand All @@ -382,6 +405,9 @@ def transform_subroutine(self, routine, **kwargs):
# Mark sequential loops inside vector sections
self.mark_seq_loops(loop.body)

# Mark outer driver loops
self.mark_driver_loop(routine, loop)


class SCCDemoteTransformation(Transformation):
"""
Expand Down
5 changes: 3 additions & 2 deletions loki/transformations/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from loki.ir import (
nodes as ir, Import, TypeDef, VariableDeclaration,
StatementFunction, Transformer, FindNodes
StatementFunction, Transformer, FindNodes, is_loki_pragma
)
from loki.module import Module
from loki.subroutine import Subroutine
Expand Down Expand Up @@ -585,7 +585,8 @@ def is_driver_loop(loop, targets):
"""
if loop.pragma:
for pragma in loop.pragma:
if pragma.keyword.lower() == "loki" and pragma.content.lower() == "driver-loop":
if is_loki_pragma(pragma, starts_with='driver-loop') or \
is_loki_pragma(pragma, starts_with='loop driver'):
return True
for call in FindNodes(ir.CallStatement).visit(loop.body):
if call.name in targets:
Expand Down

0 comments on commit 7121819

Please sign in to comment.