diff --git a/loki/transformations/raw_stack_allocator.py b/loki/transformations/raw_stack_allocator.py index 8b7983340..1d7b9768a 100644 --- a/loki/transformations/raw_stack_allocator.py +++ b/loki/transformations/raw_stack_allocator.py @@ -67,6 +67,9 @@ class TemporariesRawStackTransformation(Transformation): directive : str, optional Can be ``'openmp'`` or ``'openacc'``. If given, insert data sharing clauses for the stack derived type, and insert data transfer statements (for OpenACC only). + driver_horizontal : str, optional + Name of the variable to use as the horizontal size when allocating the stack + in the driver, instead of the default horizontal size variable. key : str, optional Overwrite the key that is used to store analysis results in ``trafo_data``. """ @@ -82,16 +85,18 @@ class TemporariesRawStackTransformation(Transformation): BasicType.INTEGER: {'kernel': 'K', 'driver': 'I'} } - def __init__(self, block_dim, horizontal, - stack_name='STACK', - local_int_var_name_pattern='JD_{name}', - directive=None, key=None, **kwargs): + def __init__( + self, block_dim, horizontal, stack_name='STACK', + local_int_var_name_pattern='JD_{name}', directive=None, + key=None, driver_horizontal=None, **kwargs + ): super().__init__(**kwargs) self.block_dim = block_dim self.horizontal = horizontal self.stack_name = stack_name self.local_int_var_name_pattern = local_int_var_name_pattern self.directive = directive + self.driver_horizontal = driver_horizontal if key: self._key = key @@ -229,8 +234,16 @@ def create_stacks_driver(self, routine, stack_dict, successors): #Create the stack variable and its type with the correct shape stack_var = self._get_stack_var(routine, dtype, kind) - stack_type = stack_var.type.clone(shape=(self._get_horizontal_variable(routine), - stack_dict[dtype][kind], kgpblock)) + horizontal_size = self._get_horizontal_variable(routine) + if self.driver_horizontal: + # If override is specified, use a separate horizontal in the driver + horizontal_size = Variable( 
name=self.driver_horizontal, scope=routine, type=self.int_type + ) + + stack_type = stack_var.type.clone( + shape=(horizontal_size, stack_dict[dtype][kind], kgpblock) + ) stack_var = stack_var.clone(type=stack_type) #Add the variables to the stack_arg_dict with dimensions (:,:,j_block) diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index de8eb3528..088c62de9 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -324,7 +324,8 @@ def transform_subroutine(self, routine, **kwargs): )) transformation = TemporariesRawStackTransformation( - block_dim=block_dim, horizontal=horizontal, directive=directive + block_dim=block_dim, horizontal=horizontal, + directive=directive, driver_horizontal='NPROMA' ) scheduler.process(transformation=transformation)