Skip to content

Commit

Permalink
feat[dace]: Modified gt_simplify() (#1647)
Browse files Browse the repository at this point in the history
Before the pass was just calling the native DaCe version. Now it will
not call the `PromoteScalarToSymbol` and `ConstantPropagation` pass.
This is because the lowering sometimes has to change between a symbol
and a scalar and back and back again and so on. Furthermore, it looks
like these passes have problems with that, so we excluded them.

This is a temporary solution, at the end, it might be feasible or good
to run the full simplify pass.
  • Loading branch information
philip-paul-mueller authored Sep 18, 2024
1 parent 08e063c commit 07e2ee9
Showing 1 changed file with 21 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

"""Fast access to the auto optimization on DaCe."""

from typing import Any, Optional, Sequence
from typing import Any, Final, Iterable, Optional, Sequence

import dace
from dace.transformation import dataflow as dace_dataflow
Expand All @@ -21,35 +21,45 @@
)


GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"}
"""Set of simplify passes `gt_simplify()` skips by default.
The following passes are included:
- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a
symbol or vice versa and at a later point to invert this again. However, this
pass has some problems with this pattern so for the time being it is disabled.
- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`.
"""


def gt_simplify(
sdfg: dace.SDFG,
validate: bool = True,
validate_all: bool = False,
skip: Optional[set[str]] = None,
skip: Optional[Iterable[str]] = GT_SIMPLIFY_DEFAULT_SKIP_SET,
) -> Any:
"""Performs simplifications on the SDFG in place.
Instead of calling `sdfg.simplify()` directly, you should use this function,
as it is specially tuned for GridTool based SDFGs.
By default this function will run the normal DaCe simplify pass, but skip
passes listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET`. If `skip` is passed it
will be forwarded to DaCe, i.e. `GT_SIMPLIFY_DEFAULT_SKIP_SET` are not
added automatically.
Args:
sdfg: The SDFG to optimize.
validate: Perform validation after the pass has run.
validate_all: Perform extensive validation.
skip: List of simplify passes that should not be applied.
Note:
The reason for this function is that we can influence how simplify works.
Since some parts in simplify might break things in the SDFG.
However, currently nothing is customized yet, and the function just calls
the simplification pass directly.
skip: List of simplify passes that should not be applied, defaults
to `GT_SIMPLIFY_DEFAULT_SKIP_SET`.
"""

return dace_passes_simplify.SimplifyPass(
validate=validate,
validate_all=validate_all,
verbose=False,
skip=skip,
skip=set(skip) if skip is not None else skip,
).apply_pass(sdfg, {})


Expand Down

0 comments on commit 07e2ee9

Please sign in to comment.