diff --git a/loki/transformations/block_index_transformations.py b/loki/transformations/block_index_transformations.py index 067c2a42a..6245db6eb 100644 --- a/loki/transformations/block_index_transformations.py +++ b/loki/transformations/block_index_transformations.py @@ -541,9 +541,12 @@ def arg_to_local_var(routine, var): type=routine.variable_map[var.name].type.clone(intent=None)),) def local_var(self, call, var): - if var.name in call.arg_map: - self.arg_to_local_var(call.routine, call.arg_map[var.name]) - elif var.name in call.routine.arguments: + # if var.name in call.arg_map: + # print(f"arg to local var [1] {var} | {call.arg_map[var.name]}") + # self.arg_to_local_var(call.routine, call.arg_map[var.name]) + # elif var.name in call.routine.arguments: + if var.name in call.routine.arguments: + print(f"arg to local var [2] {var}") self.arg_to_local_var(call.routine, var) else: call.routine.variables += (var.clone(scope=call.routine),) @@ -576,9 +579,11 @@ def process_driver(self, routine, targets): processed_routines = () calls = () for loop in loops: + lower_loop = False for call in FindNodes(ir.CallStatement).visit(loop.body): #visit(routine.body): if str(call.name).lower() not in targets: continue + lower_loop = True calls += (call,) # take a copy of the loop that will be lowered loop_to_lower = loop.clone() @@ -587,8 +592,12 @@ def process_driver(self, routine, targets): loop_to_lower = SubstituteExpressions(call_arg_map).visit(loop_to_lower) # remove calls that are not within targets # TODO: rather a hack to remove # "CALL TIMER%THREAD_LOG(TID, IGPC=ICEND)" - calls_within_loop = [call for call in FindNodes(ir.CallStatement).visit(loop_to_lower.body) - if str(call.name).lower() not in targets] + calls_within_loop = [_call for _call in FindNodes(ir.CallStatement).visit(loop_to_lower.body) + if str(_call.name).lower() not in targets] + loop_to_lower = Transformer({call: None for call in calls_within_loop}).visit(loop_to_lower) + # remove calls that are within targets except for relevant one + calls_within_loop = [_call for _call in FindNodes(ir.CallStatement).visit(loop_to_lower.body) + if str(_call.name).lower() in targets and str(_call.name).lower() != str(call.name).lower()] loop_to_lower = Transformer({call: None for call in calls_within_loop}).visit(loop_to_lower) # symbols that are defined or rather assigned within the loop defined_symbols_loop = [assign.lhs for assign in FindNodes(ir.Assignment).visit(loop_to_lower.body)] @@ -619,13 +628,15 @@ def process_driver(self, routine, targets): call._update(pragma=(call.pragma if call.pragma else ()) + call_pragmas) processed_routines += (call.routine.name,) to_local_var[call.routine.name] = defined_symbols_loop + [loop.variable] - driver_loop_map[loop] = loop.body + if lower_loop: + driver_loop_map[loop] = loop.body routine.body = Transformer(driver_loop_map).visit(routine.body) for call in calls: # FindNodes(ir.CallStatement).visit(routine.body): if str(call.name).lower() not in targets: continue # self.local_var(routine, call, loop.variable) for var in to_local_var[call.routine.name]: + print(f"self.local_var for {var}") self.local_var(call, var) # TODO: remove self.remove_openmp_pragmas(routine) diff --git a/loki/transformations/tests/test_block_index_inject.py b/loki/transformations/tests/test_block_index_inject.py index 697ec1e66..35d6d41b7 100644 --- a/loki/transformations/tests/test_block_index_inject.py +++ b/loki/transformations/tests/test_block_index_inject.py @@ -10,12 +10,13 @@ from loki import ( Dimension, gettempdir, Scheduler, OMNI, FindNodes, Assignment, FindVariables, CallStatement, Subroutine, - Item, available_frontends, Module, fgen, symbols as sym, ir + Item, available_frontends, Module, ir, get_pragma_parameters, # fgen ) from loki.transformations import ( BlockViewToFieldViewTransformation, InjectBlockIndexTransformation, LowerBlockIndexTransformation, LowerBlockLoopTransformation ) +from loki.expression import symbols as sym @pytest.fixture(scope='module', name='horizontal') def fixture_horizontal(): @@ -25,7 +26,7 @@ def fixture_horizontal(): @pytest.fixture(scope='module', name='blocking') def fixture_blocking(): - return Dimension(name='blocking', size='nb', index='ibl', index_aliases='bnds%kbl') + return Dimension(name='blocking', size='nb', index='ibl', index_aliases=('bnds%kbl', 'jkglo')) @pytest.fixture(scope='function', name='config') @@ -389,8 +390,9 @@ def test_blockview_to_fieldview_exception(frontend, horizontal): targets=('compute',)) -@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, - 'OMNI correctly complains about rank mismatch in assignment.')])) +# @pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, +# 'OMNI correctly complains about rank mismatch in assignment.')])) +@pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('block_dim_arg', (False, True)) @pytest.mark.parametrize('recurse_to_kernels', (False, True)) def test_simple_lower_loop(blocking, frontend, block_dim_arg, recurse_to_kernels): @@ -412,7 +414,6 @@ def test_simple_lower_loop(blocking, frontend, block_dim_arg, recurse_to_kernels offset = 1 !$omp test do ibl=loop_start, loop_end - ibl = ibl - offset + some_val call kernel(nlon,nlev,var(:,:,ibl), some_var(:,:,ibl),offset, loop_start, loop_end{', ibl, nb' if block_dim_arg else ''}) enddo end subroutine driver @@ -428,7 +429,8 @@ def test_simple_lower_loop(blocking, frontend, block_dim_arg, recurse_to_kernels integer, intent(in) :: nlon,nlev,icend,lstart,lend real, intent(inout) :: var(nlon,nlev) real, intent(inout) :: another_var(nlon, nlev) - {'integer, intent(in) :: ibl, nb' if block_dim_arg else ''} + {'integer, intent(in) :: ibl' if block_dim_arg else ''} + {'integer, intent(in) :: nb' if block_dim_arg else ''} integer :: jk, jl var(:,:) = 0. do jk = 1,nlev @@ -455,16 +457,17 @@ def test_simple_lower_loop(blocking, frontend, block_dim_arg, recurse_to_kernels end module compute_mod """ - # recurse_to_kernels = True # False - # kernel = Subroutine.from_source(fcode, frontend=frontend) nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend) - kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod) + kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod) driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod) - print(f"kernel.symbol_table: {dict(kernel_mod['kernel'].symbol_attrs)}") - # kernel = Subroutine.from_source(fcode, frontend=frontend) - LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(driver, role='driver', targets=('kernel',)) - LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) - LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(nested_kernel_mod['compute'], role='kernel') + + # lower block index (dimension/shape) as prerequisite for 'InjectBlockIndexTransformation' + LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(driver, + role='driver', targets=('kernel',)) + LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(kernel_mod['kernel'], + role='kernel', targets=('compute',)) + LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(nested_kernel_mod['compute'], + role='kernel') kernel_call = FindNodes(ir.CallStatement).visit(driver.body)[0] if block_dim_arg: @@ -515,11 +518,16 @@ def test_simple_lower_loop(blocking, frontend, block_dim_arg, recurse_to_kernels LowerBlockLoopTransformation(blocking).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) LowerBlockLoopTransformation(blocking).apply(nested_kernel_mod['compute'], role='kernel') - """ + driver_calls = FindNodes(ir.CallStatement).visit(driver.body) + assert driver_calls[0].pragma[0].keyword.lower() == 'loki' + assert 'removed_loop' in driver_calls[0].pragma[0].content.lower() + parameters = get_pragma_parameters(driver_calls[0].pragma, starts_with='removed_loop') + assert parameters == {'var': 'ibl', 'lower': 'loop_start', 'upper': 'loop_end', 'step': '1'} driver_loops = FindNodes(ir.Loop).visit(driver.body) kernel_loops = FindNodes(ir.Loop).visit(kernel_mod['kernel'].body) assert not any(loop.variable == blocking.index for loop in driver_loops) assert any(loop.variable == blocking.index for loop in kernel_loops) + kernel_call = FindNodes(ir.CallStatement).visit(driver.body)[0] if block_dim_arg: assert blocking.size in kernel_call.arguments assert blocking.index not in kernel_call.arguments @@ -528,33 +536,51 @@ def test_simple_lower_loop(blocking, frontend, block_dim_arg, recurse_to_kernels assert blocking.index not in [kwarg[0] for kwarg in kernel_call.kwarguments] assert blocking.size in kernel_mod['kernel'].arguments assert blocking.index not in kernel_mod['kernel'].arguments - """ + assert blocking.index in kernel_mod['kernel'].variable_map - print(f"---------------\ndriver:\n{fgen(driver)}") - print(f"---------------\nkernel:\n{fgen(kernel_mod['kernel'])}") - print(f"---------------\nkernel:\n{fgen(nested_kernel_mod['compute'])}") - print("\n\n") + # print(f"---------------\ndriver:\n{fgen(driver)}") + # print(f"---------------\nkernel:\n{fgen(kernel_mod['kernel'])}") + # print(f"---------------\nkernel:\n{fgen(nested_kernel_mod['compute'])}") + # print("\n\n") # print(f"kernel.symbol_table: {dict(kernel['kernel'].symbol_attrs)}") # assigns = FindNodes(Assignment).visit(kernel.body) # assert assigns[0].lhs == 'var(:,:,ibl)' # calls = FindNodes(CallStatement).visit(kernel.body) # assert 'var(:,:,ibl)' in calls[0].arguments -@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, - 'OMNI correctly complains about rank mismatch in assignment.')])) -def test_lower_loop(blocking, frontend): +# @pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, +# 'OMNI correctly complains about rank mismatch in assignment.')])) +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('recurse_to_kernels', (False, True)) +@pytest.mark.parametrize('targets', (('kernel', 'another_kernel', 'compute'), ('kernel', 'compute'))) +def test_lower_loop(blocking, frontend, recurse_to_kernels, targets): fcode_driver = """ subroutine driver(nlon,nlev,nb,var) use kernel_mod, only: kernel + use another_kernel_mod, only: another_kernel implicit none integer, intent(in) :: nlon,nlev,nb real, intent(inout) :: var(nlon,nlev,nb) real :: some_var(nlon,nlev,nb) - integer :: jkglo, ibl + integer :: jkglo, ibl, status, jk, jl do jkglo=1,nb,nlev ibl = (jkglo-1)/(nlev+1) call kernel(nlon,nlev,var(:,:,ibl), some_var(:,:,ibl)) + call another_kernel(nlon,nlev,var(:,:,ibl), some_var(:,:,ibl)) + status = 1 + enddo + do jkglo=1,nb,nlev + ibl = (jkglo-1)/(nlev+1) + call kernel(nlon,nlev,var(:,:,ibl), some_var(:,:,ibl)) + enddo + do jkglo=1,nb,nlev + ibl = (jkglo-1)/(nlev+1) + do jk = 1,nlev + do jl = 1, nlon + some_var(jl, jk, jkglo) = 0. + end do + end do enddo end subroutine driver """ @@ -574,6 +600,27 @@ def test_lower_loop(blocking, frontend): call compute(nlon,nlev,another_var) end subroutine kernel end module kernel_mod +""" + + fcode_another_kernel = """ +module another_kernel_mod +implicit none +contains +subroutine another_kernel(nlon,nlev,var,another_var) + implicit none + integer, intent(in) :: nlon,nlev + real, intent(inout) :: var(nlon,nlev) + real, intent(inout) :: another_var(nlon, nlev) + integer :: jk, jl + var(:,:) = 0. + do jk = 1,nlev + do jl = 1, nlon + var(jl, jk) = 0. + another_var(jl, jk) = 0. + end do + end do +end subroutine another_kernel +end module another_kernel_mod """ fcode_nested_kernel = """ @@ -584,35 +631,131 @@ def test_lower_loop(blocking, frontend): implicit none integer, intent(in) :: nlon,nlev real, intent(inout) :: var(nlon,nlev) - var(:,:) = 0. + integer :: jk, jl + do jk = 1,nlev + do jl = 1, nlon + var(jl, jk) = 0. + end do + end do end subroutine compute end module compute_mod """ - recurse_to_kernels = True - # kernel = Subroutine.from_source(fcode, frontend=frontend) nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend) kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod) - driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod) - print(f"kernel.symbol_table: {dict(kernel_mod['kernel'].symbol_attrs)}") - # kernel = Subroutine.from_source(fcode, frontend=frontend) - LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(driver, role='driver', targets=('kernel',)) - LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) - LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(nested_kernel_mod['compute'], role='kernel') - InjectBlockIndexTransformation(blocking).apply(driver, role='driver', targets=('kernel',)) - InjectBlockIndexTransformation(blocking).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) + another_kernel_mod = Module.from_source(fcode_another_kernel, frontend=frontend) + driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=(kernel_mod, another_kernel_mod)) + + LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(driver, + role='driver', targets=targets) + LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(kernel_mod['kernel'], + role='kernel', targets=targets) + LowerBlockIndexTransformation(blocking, + recurse_to_kernels=recurse_to_kernels).apply(another_kernel_mod['another_kernel'], + role='kernel', targets=targets) + LowerBlockIndexTransformation(blocking, + recurse_to_kernels=recurse_to_kernels).apply(nested_kernel_mod['compute'], + role='kernel') + + kernel_calls = [call for call in FindNodes(ir.CallStatement).visit(driver.body) + if str(call.name).lower() in targets] + for kernel_call in kernel_calls: + assert blocking.size in [kwarg[0] for kwarg in kernel_call.kwarguments] + assert blocking.index in [kwarg[0] for kwarg in kernel_call.kwarguments] + assert blocking.size in kernel_mod['kernel'].arguments + assert blocking.index in kernel_mod['kernel'].arguments + if 'another_kernel' in targets: + assert blocking.size in another_kernel_mod['another_kernel'].arguments + assert blocking.index in another_kernel_mod['another_kernel'].arguments + + kernel_array_args = [arg for arg in kernel_mod['kernel'].arguments if isinstance(arg, sym.Array)] + another_kernel_array_args = [arg for arg in another_kernel_mod['another_kernel'].arguments + if isinstance(arg, sym.Array)] + nested_kernel_array_args = [arg for arg in nested_kernel_mod['compute'].arguments if isinstance(arg, sym.Array)] + test_array_args = kernel_array_args + test_array_args += another_kernel_array_args if 'another_kernel' in targets else [] + test_array_args += nested_kernel_array_args if recurse_to_kernels else [] + for array in test_array_args: + assert blocking.size in array.dimensions + assert blocking.size in array.shape + if not recurse_to_kernels: + for array in nested_kernel_array_args: + assert blocking.size not in array.dimensions + assert blocking.size not in array.shape + + arrays = [var for var in FindVariables().visit(kernel_mod['kernel'].body) if isinstance(var, sym.Array)] + arrays += [var for var in FindVariables().visit(another_kernel_mod['another_kernel'].body) + if isinstance(var, sym.Array)] if 'another_kernel' in targets else [] + arrays += [var for var in FindVariables().visit(nested_kernel_mod['compute'].body) + if isinstance(var, sym.Array)] if recurse_to_kernels else [] + for array in arrays: + if array.name.lower() in [arg.name.lower() for arg in test_array_args]: + assert blocking.size in array.shape + assert blocking.index not in array.dimensions + + InjectBlockIndexTransformation(blocking).apply(driver, role='driver', targets=targets) + InjectBlockIndexTransformation(blocking).apply(kernel_mod['kernel'], role='kernel', targets=targets) + InjectBlockIndexTransformation(blocking).apply(another_kernel_mod['another_kernel'], role='kernel', targets=targets) InjectBlockIndexTransformation(blocking).apply(nested_kernel_mod['compute'], role='kernel') - LowerBlockLoopTransformation(blocking).apply(driver, role='driver', targets=('kernel',)) - LowerBlockLoopTransformation(blocking).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) + arrays = [var for var in FindVariables().visit(kernel_mod['kernel'].body) if isinstance(var, sym.Array)] + arrays += [var for var in FindVariables().visit(another_kernel_mod['another_kernel'].body) + if isinstance(var, sym.Array)] if 'another_kernel' in targets else [] + arrays += [var for var in FindVariables().visit(nested_kernel_mod['compute'].body) + if isinstance(var, sym.Array)] if recurse_to_kernels else [] + for array in arrays: + if array.name.lower() in [arg.name.lower() for arg in test_array_args]: + assert blocking.size in array.shape + assert not array.dimensions or blocking.index in array.dimensions + + driver_loops = FindNodes(ir.Loop).visit(driver.body) + kernel_loops = FindNodes(ir.Loop).visit(kernel_mod['kernel'].body) + another_kernel_loops = FindNodes(ir.Loop).visit(another_kernel_mod['another_kernel'].body) + assert any(loop.variable == blocking.index or loop.variable in blocking._index_aliases for loop in driver_loops) + assert not any(loop.variable == blocking.index or loop.variable in blocking._index_aliases for loop in kernel_loops) + if 'another_kernel' in targets: + assert not any(loop.variable == blocking.index or loop.variable + in blocking._index_aliases for loop in another_kernel_loops) + + LowerBlockLoopTransformation(blocking).apply(driver, role='driver', targets=targets) + LowerBlockLoopTransformation(blocking).apply(kernel_mod['kernel'], role='kernel', targets=targets) + LowerBlockLoopTransformation(blocking).apply(another_kernel_mod['another_kernel'], role='kernel', targets=targets) LowerBlockLoopTransformation(blocking).apply(nested_kernel_mod['compute'], role='kernel') - print(f"---------------\ndriver:\n{fgen(driver)}") - print(f"---------------\nkernel:\n{fgen(kernel_mod['kernel'])}") - print(f"---------------\nkernel:\n{fgen(nested_kernel_mod['compute'])}") - print("\n\n") - # print(f"kernel.symbol_table: {dict(kernel['kernel'].symbol_attrs)}") - # assigns = FindNodes(Assignment).visit(kernel.body) - # assert assigns[0].lhs == 'var(:,:,ibl)' - # calls = FindNodes(CallStatement).visit(kernel.body) - # assert 'var(:,:,ibl)' in calls[0].arguments + driver_calls = [call for call in FindNodes(ir.CallStatement).visit(driver.body) if call.pragma is not None] + if 'another_kernel' in targets: + assert len(driver_calls) == 3 + else: + assert len(driver_calls) == 2 + for driver_call in driver_calls: + assert driver_call.pragma[0].keyword.lower() == 'loki' + assert 'removed_loop' in driver_call.pragma[0].content.lower() + parameters = get_pragma_parameters(driver_call.pragma, starts_with='removed_loop') + assert parameters == {'var': 'jkglo', 'lower': '1', 'upper': 'nb', 'step': 'nlev'} + driver_loops = FindNodes(ir.Loop).visit(driver.body) + kernel_loops = FindNodes(ir.Loop).visit(kernel_mod['kernel'].body) + another_kernel_loops = FindNodes(ir.Loop).visit(another_kernel_mod['another_kernel'].body) + assert len([loop for loop in driver_loops if loop.variable == blocking.index + or loop.variable in blocking._index_aliases]) == 1 + assert any(loop.variable == blocking.index or loop.variable in blocking._index_aliases + for loop in kernel_loops) + if 'another_kernel' in targets: + assert any(loop.variable == blocking.index or loop.variable in blocking._index_aliases + for loop in another_kernel_loops) + kernel_call = FindNodes(ir.CallStatement).visit(driver.body)[0] + assert blocking.size in [kwarg[0] for kwarg in kernel_call.kwarguments] + assert blocking.index not in [kwarg[0] for kwarg in kernel_call.kwarguments] + for index_alias in blocking._index_aliases: + assert index_alias not in [kwarg[0] for kwarg in kernel_call.kwarguments] + assert index_alias not in kernel_mod['kernel'].arguments + if 'another_kernel' in targets: + assert index_alias not in another_kernel_mod['another_kernel'].arguments + assert blocking.size in kernel_mod['kernel'].arguments + assert blocking.index not in kernel_mod['kernel'].arguments + assert blocking.index in kernel_mod['kernel'].variable_map + + # print(f"---------------\ndriver:\n{fgen(driver)}") + # print(f"---------------\nkernel:\n{fgen(kernel_mod['kernel'])}") + # print(f"---------------\nkernel:\n{fgen(another_kernel_mod['another_kernel'])}") + # print(f"---------------\nkernel:\n{fgen(nested_kernel_mod['compute'])}") + # print("\n\n")