Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive inlining via InlineTransform and associated fixes #205

Merged
merged 13 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions loki/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from loki.transform.transformation import * # noqa
from loki.transform.transform_utilities import * # noqa
from loki.transform.transform_array_indexing import * # noqa
from loki.transform.transform_associates import * # noqa
from loki.transform.transform_inline import * # noqa
from loki.transform.transform_loop import * # noqa
from loki.transform.transform_region import * # noqa
Expand All @@ -20,5 +19,5 @@
from loki.transform.transform_hoist_variables import * # noqa
from loki.transform.transform_parametrise import * # noqa
from loki.transform.transform_extract_contained_procedures import * # noqa
from loki.transform.transform_sequence_association import * # noqa
from loki.transform.transform_dead_code import * # noqa
from loki.transform.transform_sanitise import * # noqa
2 changes: 1 addition & 1 deletion loki/transform/fortran_c_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
resolve_vector_notation, normalize_array_shape_and_access,
flatten_arrays
)
from loki.transform.transform_associates import resolve_associates
from loki.transform.transform_sanitise import resolve_associates
from loki.transform.transform_utilities import (
convert_to_lower_case, replace_intrinsics, sanitise_imports
)
Expand Down
2 changes: 1 addition & 1 deletion loki/transform/fortran_python_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from loki.transform.transform_array_indexing import (
shift_to_zero_indexing, invert_array_indices, normalize_range_indexing
)
from loki.transform.transform_associates import resolve_associates
from loki.transform.transform_sanitise import resolve_associates
from loki.transform.transform_utilities import (
convert_to_lower_case, replace_intrinsics
)
Expand Down
69 changes: 0 additions & 69 deletions loki/transform/transform_associates.py

This file was deleted.

143 changes: 132 additions & 11 deletions loki/transform/transform_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from loki.tools import as_tuple
from loki.logging import warning, error
from loki.pragma_utils import pragmas_attached, is_loki_pragma
from loki.subroutine import Subroutine

from loki.transform.transformation import Transformation
from loki.transform.transform_dead_code import dead_code_elimination
from loki.transform.transform_utilities import (
single_variable_declaration,
recursive_expression_map_update
Expand All @@ -32,10 +35,89 @@
__all__ = [
'inline_constant_parameters', 'inline_elemental_functions',
'inline_internal_procedures', 'inline_member_procedures',
'inline_marked_subroutines'
'inline_marked_subroutines', 'InlineTransformation'
]


class InlineTransformation(Transformation):
    """
    Batch-processing :any:`Transformation` that bundles the individual
    source inlining utilities, so they can be applied to large source
    trees via the :any:`Scheduler`.

    The enabled steps are always applied in the same order: constant
    parameters, elemental functions, internal procedures, pragma-marked
    subroutines and, last, dead code elimination.

    Parameters
    ----------
    inline_constants : bool
        Replace variables with known constant values by :any:`Literal`
        (see :any:`inline_constant_parameters`); default: False.
    inline_elementals : bool
        Replace :any:`InlineCall` expressions to elemental functions
        with the body of the called function
        (see :any:`inline_elemental_functions`); default: True.
    inline_internals : bool
        Inline internal procedures
        (see :any:`inline_internal_procedures`); default: False.
    inline_marked : bool
        Inline :any:`Subroutine` objects that are marked by pragma
        annotations (see :any:`inline_marked_subroutines`); default: True.
    eliminate_dead_code : bool
        Run dead code elimination after inlining, trimming unreachable
        branches (see :any:`dead_code_elimination`); default: True.
    allowed_aliases : tuple or list of str or :any:`Expression`, optional
        Variables that will not be renamed in the parent scope during
        internal and pragma-driven inlining.
    remove_imports : bool, optional
        Strip unused import symbols after pragma-driven inlining
        (default: True).
    external_only : bool, optional
        Do not replace variables declared in the local scope when
        inlining constants (default: True).
    """

    # Traverse from the leaves, so that recursive inlining is resolved correctly
    reverse_traversal = True

    def __init__(
            self, inline_constants=False, inline_elementals=True,
            inline_internals=False, inline_marked=True,
            eliminate_dead_code=True, allowed_aliases=None,
            remove_imports=True, external_only=True
    ):
        # Switches selecting the individual inlining steps
        self.inline_constants = inline_constants
        self.inline_elementals = inline_elementals
        self.inline_internals = inline_internals
        self.inline_marked = inline_marked

        # Optional clean-up step that runs after all inlining
        self.eliminate_dead_code = eliminate_dead_code

        # Options that are forwarded to the inlining utilities
        self.allowed_aliases = allowed_aliases
        self.remove_imports = remove_imports
        self.external_only = external_only

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply every enabled inlining step to :data:`routine`.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The procedure that is passed to the inlining utilities.
        """
        aliases = self.allowed_aliases

        # Step 1: substitute constant parameter variables with their values
        if self.inline_constants:
            inline_constant_parameters(routine, external_only=self.external_only)

        # Step 2: elemental function calls
        if self.inline_elementals:
            inline_elemental_functions(routine)

        # Step 3: internal (contained) procedures
        if self.inline_internals:
            inline_internal_procedures(routine, allowed_aliases=aliases)

        # Step 4: subroutine calls explicitly marked via pragma
        if self.inline_marked:
            inline_marked_subroutines(
                routine, allowed_aliases=aliases,
                remove_imports=self.remove_imports
            )

        # Finally, try to trim code paths that have become unreachable
        if self.eliminate_dead_code:
            dead_code_elimination(routine)


class InlineSubstitutionMapper(LokiIdentityMapper):
"""
An expression mapper that defines symbolic substitution for inlining.
Expand Down Expand Up @@ -101,15 +183,23 @@ def inline_constant_parameters(routine, external_only=True):
"""
Replace instances of variables with known constant values by `Literals`.

:param external_only: Do not replace variables declared in the local scope
Notes
-----
The ``.type.initial`` property is used to derive the replacement
value, which means for symbols imported from external modules,
the parent :any:`Module` needs to be supplied in the
``definitions`` to the constructor when creating :param:`routine`.

Note, the `.type.initial` property is used to derive the replacement value,
which means for symbols imported from external modules, the parent `Module`
needs to be supplied in the `definitions` to the constructor when creating
:param routine:.
Variables that are replaced are also removed from their
corresponding import statements, with empty import statements
being removed altogether.

Variables that are replaced are also removed from their corresponding import
statements, with empty import statements being removed alltogether.
Parameters
----------
routine : :any:`Subroutine`
Procedure in which to inline/resolve constant parameters.
external_only : bool, optional
Do not replace variables declared in the local scope (default: True)
"""
# Find all variable instances in spec and body
variables = FindVariables().visit(routine.ir)
Expand Down Expand Up @@ -179,7 +269,10 @@ def inline_elemental_functions(routine):

exprmap = {}
for call in FindInlineCalls().visit(routine.body):
if call.procedure_type is not BasicType.DEFERRED:
if call.procedure_type is BasicType.DEFERRED:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conceptual question, not necessary to act upon here: Going forward, should we try to include debug output (or a new "optimisation report"-like output stream) in situations like this, where a transformation is not applied because an assumption isn't met (here: missing enrichment information).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm planning an overhaul of the logging from transformations once the configuration changes are more complete (so we can select log-level per trafo, etc.).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly worth capturing via an issue?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

continue

if call.procedure_type.is_function and call.procedure_type.is_elemental:
# Map each call to its substitutions, as defined by the
# recursive inline substitution mapper
exprmap[call] = InlineSubstitutionMapper()(call, scope=routine)
Expand All @@ -193,7 +286,7 @@ def inline_elemental_functions(routine):
# Remove all module imports that have become obsolete now
import_map = {}
for im in FindNodes(Import).visit(routine.spec):
if all(hasattr(s, 'type') and s.type.dtype in removed_functions for s in im.symbols):
if im.symbols and all(s.type.dtype in removed_functions for s in im.symbols):
import_map[im] = None
routine.spec = Transformer(import_map).visit(routine.spec)

Expand Down Expand Up @@ -319,6 +412,7 @@ def inline_subroutine_calls(routine, calls, callee, allowed_aliases=None):

# Ensure we process sets of calls to the same callee
assert all(call.routine == callee for call in calls)
assert isinstance(callee, Subroutine)

# Prevent shadowing of callee's variables by renaming them a priori
parent_variables = routine.variable_map
Expand Down Expand Up @@ -397,7 +491,7 @@ def inline_internal_procedures(routine, allowed_aliases=None):
inline_member_procedures = inline_internal_procedures


def inline_marked_subroutines(routine, allowed_aliases=None):
def inline_marked_subroutines(routine, allowed_aliases=None, remove_imports=True):
"""
Inline :any:`Subroutine` objects guided by pragma annotations.

Expand All @@ -416,19 +510,46 @@ def inline_marked_subroutines(routine, allowed_aliases=None):
allowed_aliases : tuple or list of str or :any:`Expression`, optional
List of variables that will not be renamed in the parent scope, even
if they alias with a local declaration.
remove_imports : bool
Strip unused import symbols after inlining (optional, default: True)
"""

with pragmas_attached(routine, node_type=CallStatement):

# Group the marked calls by callee routine
call_sets = defaultdict(list)
no_call_sets = defaultdict(list)
for call in FindNodes(CallStatement).visit(routine.body):
if call.routine == BasicType.DEFERRED:
continue

if is_loki_pragma(call.pragma, starts_with='inline'):
call_sets[call.routine].append(call)
else:
no_call_sets[call.routine].append(call)

# Trigger per-call inlining on collected sets
for callee, calls in call_sets.items():
if callee: # Skip the unattached calls (collected under None)
inline_subroutine_calls(
routine, calls, callee, allowed_aliases=allowed_aliases
)

# Remove imported symbols that have become obsolete
if remove_imports:
callees = tuple(callee.procedure_symbol for callee in call_sets.keys())
not_inlined = tuple(callee.procedure_symbol for callee in no_call_sets.keys())

import_map = {}
for impt in FindNodes(Import).visit(routine.spec):
# Remove interface header imports
if any(f'{c.name.lower()}.intfb.h' == impt.module for c in callees):
import_map[impt] = None

if any(s.name in callees for s in impt.symbols):
new_symbols = tuple(
s for s in impt.symbols if s.name not in callees or s.name in not_inlined
)
# Remove import if no further symbols used, otherwise clone with new symbols
import_map[impt] = impt.clone(symbols=new_symbols) if new_symbols else None
routine.spec = Transformer(import_map).visit(routine.spec)
Loading
Loading