feat[next]: Add memory and disk-based caching to more workflow steps #1690

Merged on Nov 7, 2024 (67 commits)
Changes from 66 commits
Commits
7e78704
Fix lowering_utils._expr_hash stability across runs
tehrengruber Feb 28, 2024
7515d07
Pass manager caching
tehrengruber Feb 28, 2024
54bbea0
Fix format
tehrengruber Feb 28, 2024
0cfc583
Merge branch 'fix_caching2' into pass_manager_caching
tehrengruber Feb 28, 2024
04d8459
Use content_hash
tehrengruber Feb 28, 2024
eac113c
Fix import
tehrengruber Feb 28, 2024
679521a
Small fix
tehrengruber Feb 28, 2024
2e3f17b
Small fix
tehrengruber Feb 28, 2024
2803f41
Merge branch 'fix_caching2' into pass_manager_caching
tehrengruber Feb 28, 2024
d6adb0e
Small fix
tehrengruber Feb 28, 2024
b39b22f
Small cleanup
tehrengruber Feb 28, 2024
06c0384
Merge branch 'fix_caching2' into pass_manager_caching
tehrengruber Feb 28, 2024
bbc3b4d
Fix cache dir creation
tehrengruber May 14, 2024
3e0b021
Merge origin/main
tehrengruber Jun 21, 2024
8f5aedd
Add todo
tehrengruber Jun 21, 2024
e039057
Fix typo
tehrengruber Jun 21, 2024
5a8b525
Cleanup
tehrengruber Jun 21, 2024
cf213c4
Add tests
tehrengruber Jun 25, 2024
b724a8d
Merge pull request #1 from GridTools/main
SF-N Jul 8, 2024
6bd5652
Merge remote-tracking branch 'origin/main'
SF-N Oct 10, 2024
ff055f1
Merge remote-tracking branch 'origin/main'
SF-N Oct 15, 2024
cdfd6b5
Extend lap_test to also run benchmark
SF-N Oct 15, 2024
6806f01
Add pytest-benchmark to requirements-dev
SF-N Oct 16, 2024
fb93b73
Disable pytest-xdist to enable benchmarking
SF-N Oct 16, 2024
ad76088
remove -n from tox.ini
SF-N Oct 16, 2024
256d9de
Enable pytest-xdist and run benchmarks separately
SF-N Oct 16, 2024
f00c60a
Change order for faster CI results
SF-N Oct 16, 2024
ccc35d0
Add pytest flags and reformat
SF-N Oct 16, 2024
780a955
Use larger grid for lap_benchmark
SF-N Oct 16, 2024
3effb43
Add pytest.ini for benchmark comparison
SF-N Oct 16, 2024
01bfda8
Add pygal to requirements for histogram
SF-N Oct 16, 2024
9995bec
Remove benchmark comparison
SF-N Oct 16, 2024
afd3e96
Set cached = True in func_to_past_factory and reformat
SF-N Oct 16, 2024
f72c84f
Set cached = False in func_to_past_factory because of missing perform…
SF-N Oct 16, 2024
736a924
WIP remote debugging session
egparedes Oct 16, 2024
833c1e9
Remove debugging output
SF-N Oct 16, 2024
1d3310c
Add further performance improvements with try-except in gtfn and backend
SF-N Oct 16, 2024
90414cc
Revert changes in gtfn.py
SF-N Oct 24, 2024
570385c
Merge origin/main
SF-N Oct 24, 2024
94fcb90
Merge branch 'main' into optimize_program
SF-N Oct 24, 2024
d856bb3
Merge branch 'main' into optimize_program
SF-N Oct 28, 2024
77f4d2f
Merge remote-tracking branch 'origin_tehrengruber/pass_manager_cachin…
SF-N Oct 28, 2024
174c7fc
File caching with CachedStep (still commented in gtfn because of erro…
SF-N Oct 29, 2024
26c7060
Merge branch 'main' into optimize_program
SF-N Oct 30, 2024
7aba4bd
Use diskcache instead of shelve, cleanup and add new test
SF-N Oct 31, 2024
28eb4da
Update caching and tests
SF-N Oct 31, 2024
27a141b
Remove test
SF-N Oct 31, 2024
dc04786
Cleanup
SF-N Nov 1, 2024
ff8a162
Fix skipping in test_execution
SF-N Nov 1, 2024
f31957d
Try to fix tox.ini
SF-N Nov 1, 2024
2bf3d11
Merge branch 'main' into optimize_program
SF-N Nov 1, 2024
f3dd7ab
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
e0db9fd
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
7a7c0e5
Further cleanup
SF-N Nov 4, 2024
271dd49
Run pre-commit
SF-N Nov 4, 2024
c4e5e32
Further cleanup
SF-N Nov 4, 2024
406d5e1
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
e908009
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
98ab673
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
8f15ee2
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
b7241bf
Add docstrings and move tests
SF-N Nov 4, 2024
c7bcc6a
Double-check add_content_to_fingerprint
SF-N Nov 5, 2024
3f76e88
Merge branch 'main' into optimize_program
SF-N Nov 5, 2024
87d4e94
Address review comments
SF-N Nov 5, 2024
23a1432
Update requirements
egparedes Nov 5, 2024
fd6bde3
Merge branch 'main' into optimize_program
SF-N Nov 7, 2024
9404b4a
Address review comments
SF-N Nov 7, 2024
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -94,13 +94,14 @@ repos:
   - astunparse==1.6.3
   - attrs==24.2.0
   - black==24.8.0
-  - boltons==24.0.0
+  - boltons==24.1.0
   - cached-property==2.0.1
   - click==8.1.7
   - cmake==3.30.5
   - cytoolz==1.0.0
   - deepdiff==8.0.1
   - devtools==0.12.2
+  - diskcache==5.6.3
   - factory-boy==3.3.1
   - frozendict==2.4.6
   - gridtools-cpp==2.3.6
9 changes: 5 additions & 4 deletions constraints.txt
@@ -13,10 +13,10 @@ attrs==24.2.0 # via gt4py (pyproject.toml), hypothesis, jsonschema,
 babel==2.16.0 # via sphinx
 backcall==0.2.0 # via ipython
 black==24.8.0 # via gt4py (pyproject.toml)
-boltons==24.0.0 # via gt4py (pyproject.toml)
+boltons==24.1.0 # via gt4py (pyproject.toml)
 bracex==2.5.post1 # via wcmatch
 build==1.2.2.post1 # via pip-tools
-bump-my-version==0.28.0 # via -r requirements-dev.in
+bump-my-version==0.28.1 # via -r requirements-dev.in
 cached-property==2.0.1 # via gt4py (pyproject.toml)
 cachetools==5.5.0 # via tox
 certifi==2024.8.30 # via requests
@@ -40,6 +40,7 @@ decorator==5.1.1 # via ipython
 deepdiff==8.0.1 # via gt4py (pyproject.toml)
 devtools==0.12.2 # via gt4py (pyproject.toml)
 dill==0.3.9 # via dace
+diskcache==5.6.3 # via gt4py (pyproject.toml)
 distlib==0.3.9 # via virtualenv
 docutils==0.20.1 # via sphinx, sphinx-rtd-theme
 exceptiongroup==1.2.2 # via hypothesis, pytest
@@ -135,7 +136,7 @@ pyzmq==26.2.0 # via ipykernel, jupyter-client
 questionary==2.0.1 # via bump-my-version
 referencing==0.35.1 # via jsonschema, jsonschema-specifications
 requests==2.32.3 # via sphinx
-rich==13.9.3 # via bump-my-version, rich-click, tach
+rich==13.9.4 # via bump-my-version, rich-click, tach
 rich-click==1.8.3 # via bump-my-version
 rpds-py==0.20.1 # via jsonschema, referencing
 ruff==0.7.2 # via -r requirements-dev.in
@@ -158,7 +159,7 @@ stack-data==0.6.3 # via ipython
 stdlib-list==0.10.0 # via tach
 sympy==1.12.1 # via dace, gt4py (pyproject.toml)
 tabulate==0.9.0 # via gt4py (pyproject.toml)
-tach==0.14.1 # via -r requirements-dev.in
+tach==0.14.2 # via -r requirements-dev.in
 tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox
 tomli-w==1.0.0 # via tach
 tomlkit==0.13.2 # via bump-my-version
1 change: 1 addition & 0 deletions min-extra-requirements-test.txt
@@ -65,6 +65,7 @@ dace==0.16.1
 darglint==1.6
 deepdiff==5.6.0
 devtools==0.6
+diskcache==5.6.3
 factory-boy==3.3.0
 frozendict==2.3
 gridtools-cpp==2.3.6
1 change: 1 addition & 0 deletions min-requirements-test.txt
@@ -61,6 +61,7 @@ cytoolz==0.12.1
 darglint==1.6
 deepdiff==5.6.0
 devtools==0.6
+diskcache==5.6.3
 factory-boy==3.3.0
 frozendict==2.3
 gridtools-cpp==2.3.6
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
   'cytoolz>=0.12.1',
   'deepdiff>=5.6.0',
   'devtools>=0.6',
+  'diskcache>=5.6.3',
   'factory-boy>=3.3.0',
   'frozendict>=2.3',
   'gridtools-cpp>=2.3.6,==2.*',
9 changes: 5 additions & 4 deletions requirements-dev.txt
@@ -13,10 +13,10 @@ attrs==24.2.0 # via -c constraints.txt, gt4py (pyproject.toml), hypo
 babel==2.16.0 # via -c constraints.txt, sphinx
 backcall==0.2.0 # via -c constraints.txt, ipython
 black==24.8.0 # via -c constraints.txt, gt4py (pyproject.toml)
-boltons==24.0.0 # via -c constraints.txt, gt4py (pyproject.toml)
+boltons==24.1.0 # via -c constraints.txt, gt4py (pyproject.toml)
 bracex==2.5.post1 # via -c constraints.txt, wcmatch
 build==1.2.2.post1 # via -c constraints.txt, pip-tools
-bump-my-version==0.28.0 # via -c constraints.txt, -r requirements-dev.in
+bump-my-version==0.28.1 # via -c constraints.txt, -r requirements-dev.in
 cached-property==2.0.1 # via -c constraints.txt, gt4py (pyproject.toml)
 cachetools==5.5.0 # via -c constraints.txt, tox
 certifi==2024.8.30 # via -c constraints.txt, requests
@@ -40,6 +40,7 @@ decorator==5.1.1 # via -c constraints.txt, ipython
 deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml)
 devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml)
 dill==0.3.9 # via -c constraints.txt, dace
+diskcache==5.6.3 # via -c constraints.txt, gt4py (pyproject.toml)
 distlib==0.3.9 # via -c constraints.txt, virtualenv
 docutils==0.20.1 # via -c constraints.txt, sphinx, sphinx-rtd-theme
 exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest
@@ -135,7 +136,7 @@ pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client
 questionary==2.0.1 # via -c constraints.txt, bump-my-version
 referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications
 requests==2.32.3 # via -c constraints.txt, sphinx
-rich==13.9.3 # via -c constraints.txt, bump-my-version, rich-click, tach
+rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach
 rich-click==1.8.3 # via -c constraints.txt, bump-my-version
 rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing
 ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in
@@ -157,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython
 stdlib-list==0.10.0 # via -c constraints.txt, tach
 sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml)
 tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml)
-tach==0.14.1 # via -c constraints.txt, -r requirements-dev.in
+tach==0.14.2 # via -c constraints.txt, -r requirements-dev.in
 tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox
 tomli-w==1.0.0 # via -c constraints.txt, tach
 tomlkit==0.13.2 # via -c constraints.txt, bump-my-version
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/func_to_past.py
@@ -64,7 +64,7 @@ def func_to_past(inp: DSL_PRG) -> PRG:
     )


-def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]:
+def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PRG]:
     """
     Wrap `func_to_past` in a chainable and optionally cached workflow step.

8 changes: 6 additions & 2 deletions src/gt4py/next/ffront/stages.py
@@ -100,6 +100,7 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No

 @add_content_to_fingerprint.register(FieldOperatorDefinition)
 @add_content_to_fingerprint.register(FoastOperatorDefinition)
+@add_content_to_fingerprint.register(ProgramDefinition)
 @add_content_to_fingerprint.register(PastProgramDefinition)
 @add_content_to_fingerprint.register(toolchain.CompilableProgram)
 @add_content_to_fingerprint.register(arguments.CompileTimeArgs)
@@ -121,10 +122,14 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo
     for item in sourcedef:
         add_content_to_fingerprint(item, hasher)
+
+    closure_vars = source_utils.get_closure_vars_from_function(obj)
+    for item in sorted(closure_vars.items(), key=lambda x: x[0]):
+        add_content_to_fingerprint(item, hasher)


 @add_content_to_fingerprint.register
 def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None:
-    for key, value in obj.items():
+    for key, value in sorted(obj.items()):
         add_content_to_fingerprint(key, hasher)
         add_content_to_fingerprint(value, hasher)

@@ -148,4 +153,3 @@ def add_foast_located_node_to_fingerprint(
 ) -> None:
     add_content_to_fingerprint(obj.location, hasher)
     add_content_to_fingerprint(str(obj), hasher)
-    add_content_to_fingerprint(str(obj), hasher)
SF-N (Contributor Author), Nov 5, 2024:

I am not sure why this was done twice. Does removing it make sense, @egparedes / @DropD?

Contributor:

It looks like a bug to me and I agree with your change.
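
For readers following the `add_dict_to_fingerprint` change in this file: iterating `sorted(obj.items())` makes the fingerprint independent of dict insertion order. Below is a minimal sketch of that idea using only `hashlib`. The helpers `fingerprint` and `digest` are hypothetical stand-ins and are much simpler than gt4py's `add_content_to_fingerprint`; they assume sortable keys and use `repr` for leaf values.

```python
import hashlib
from typing import Any


def fingerprint(obj: Any, hasher: Any) -> None:
    """Feed a stable byte representation of `obj` into `hasher` (illustrative only)."""
    if isinstance(obj, dict):
        # Sort by key so that insertion order does not influence the digest.
        for key, value in sorted(obj.items()):
            fingerprint(key, hasher)
            fingerprint(value, hasher)
    else:
        # Assumption: repr() is a good enough stable encoding for leaf values here.
        hasher.update(repr(obj).encode())


def digest(obj: Any) -> str:
    hasher = hashlib.sha256()
    fingerprint(obj, hasher)
    return hasher.hexdigest()


a = {"x": 1, "y": 2}
b = {"y": 2, "x": 1}  # same content, different insertion order
order_independent = digest(a) == digest(b)
```

Without the `sorted` call, `a` and `b` would feed their items to the hasher in different orders and could produce different digests for equal dicts.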

4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/ir.py
@@ -208,7 +208,9 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait):
     closures: List[StencilClosure]
     implicit_domain: bool = False

-    _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS]
+    _NODE_SYMBOLS_: ClassVar[List[Sym]] = [
+        Sym(id=name) for name in sorted(BUILTINS)
+    ]  # sorted for serialization stability


 class Stmt(Node): ...
8 changes: 4 additions & 4 deletions src/gt4py/next/otf/workflow.py
@@ -12,6 +12,7 @@
 import dataclasses
 import functools
 import typing
+from collections.abc import MutableMapping
 from typing import Any, Callable, Generic, Protocol, TypeVar

 from typing_extensions import Self
@@ -253,16 +254,15 @@ class CachedStep(

     step: Workflow[StartT, EndT]
     hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash)  # type: ignore[assignment]
-
-    _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict)
+    cache: MutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict)

     def __call__(self, inp: StartT) -> EndT:
         """Run the step only if the input is not cached, else return from cache."""
         hash_ = self.hash_function(inp)
         try:
-            result = self._cache[hash_]
+            result = self.cache[hash_]
         except KeyError:
-            result = self._cache[hash_] = self.step(inp)
+            result = self.cache[hash_] = self.step(inp)
         return result


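The `CachedStep` change above turns the private `_cache` dict into a public `cache` field typed as `MutableMapping`, so callers can inject any mapping-like backend. The following is a simplified, self-contained sketch of that pattern, not the gt4py class itself: `CachedStepSketch` and `expensive` are hypothetical names, and the generics and `Workflow` protocol are omitted.

```python
import dataclasses
from collections.abc import MutableMapping
from typing import Any, Callable


@dataclasses.dataclass
class CachedStepSketch:
    """Call `step` only on a cache miss; any MutableMapping can serve as the cache."""

    step: Callable[[Any], Any]
    hash_function: Callable[[Any], Any] = hash
    cache: MutableMapping = dataclasses.field(repr=False, default_factory=dict)

    def __call__(self, inp: Any) -> Any:
        key = self.hash_function(inp)
        try:
            return self.cache[key]
        except KeyError:
            # Miss: run the wrapped step and remember its result under the key.
            result = self.cache[key] = self.step(inp)
            return result


calls = []


def expensive(x: int) -> int:
    calls.append(x)  # record every real invocation
    return x * 2


shared = {}  # injected cache; a disk-backed mapping could be passed here instead
cached = CachedStepSketch(expensive, cache=shared)
first = cached(21)
second = cached(21)  # served from `shared`, `expensive` is not called again
```

Because the cache is an ordinary constructor argument, the same step can run against an in-memory dict in tests and a persistent mapping in production.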
@@ -213,6 +213,7 @@ def generate_stencil_source(
             generated_code = GTFNIMCodegen.apply(gtfn_im_ir)
         else:
             generated_code = GTFNCodegen.apply(gtfn_ir)
+
         return codegen.format_source("cpp", generated_code, style="LLVM")

     def __call__(
61 changes: 54 additions & 7 deletions src/gt4py/next/program_processors/runners/gtfn.py
@@ -8,16 +8,19 @@

 import functools
 import warnings
-from typing import Any
+from typing import Any, Optional

+import diskcache
 import factory
 import numpy.typing as npt

 import gt4py._core.definitions as core_defs
 import gt4py.next.allocators as next_allocators
+from gt4py.eve import utils
 from gt4py.eve.utils import content_hash
 from gt4py.next import backend, common, config
-from gt4py.next.iterator import transforms
+from gt4py.next.common import Connectivity, Dimension
+from gt4py.next.iterator import ir as itir, transforms
 from gt4py.next.otf import arguments, recipes, stages, workflow
 from gt4py.next.otf.binding import nanobind
 from gt4py.next.otf.compilation import compiler
@@ -116,6 +119,37 @@ def compilation_hash(otf_closure: stages.CompilableProgram) -> int:
     )


+def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str:
+    """
+    Generates a unique hash string for a stencil source program representing
+    the program, sorted offset_provider, and column_axis.
+    """
+    program: itir.FencilDefinition | itir.Program = inp.data
+    offset_provider: dict[str, Connectivity | Dimension] = inp.args.offset_provider
+    column_axis: Optional[common.Dimension] = inp.args.column_axis
+
+    program_hash = utils.content_hash(
+        (
+            program,
+            sorted(offset_provider.items(), key=lambda el: el[0]),
+            column_axis,
+        )
+    )
+
+    return program_hash
+
+
+class FileCache(diskcache.Cache):
+    """
+    This class extends `diskcache.Cache` to ensure the cache is closed upon deletion,
+    i.e. it ensures that any resources associated with the cache are properly
+    released when the instance is garbage collected.
+    """
+
+    def __del__(self) -> None:
+        self.close()
+
+
 class GTFNCompileWorkflowFactory(factory.Factory):
     class Meta:
         model = recipes.OTFCompileWorkflow
@@ -129,10 +163,23 @@ class Params:
             lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type)
         )

-    translation = factory.SubFactory(
-        gtfn_module.GTFNTranslationStepFactory,
-        device_type=factory.SelfAttribute("..device_type"),
-    )
+        cached_translation = factory.Trait(
+            translation=factory.LazyAttribute(
+                lambda o: workflow.CachedStep(
+                    o.translation_,
+                    hash_function=fingerprint_compilable_program,
+                    cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")),
+                )
+            ),
+        )
+
+        translation_ = factory.SubFactory(
+            gtfn_module.GTFNTranslationStepFactory,
+            device_type=factory.SelfAttribute("..device_type"),
+        )
+
+    translation = factory.LazyAttribute(lambda o: o.translation_)
+
     bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = (
         nanobind.bind_source
     )
@@ -193,7 +240,7 @@ class Params:
     name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True
 )

-run_gtfn_cached = GTFNBackendFactory(cached=True)
+run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True)

 run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True)

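A note on why `fingerprint_compilable_program` returns a content-hash string rather than relying on the built-in `hash`: keys stored in an on-disk cache must be identical across interpreter runs, and Python salts `str` hashes per process unless `PYTHONHASHSEED` is fixed. The sketch below illustrates this with standard-library tools only. The helper `run_in_fresh_interpreter` and the sample payload are hypothetical, and `hashlib.sha256` stands in for `utils.content_hash`, whose internals are not shown in this diff.

```python
import hashlib
import os
import subprocess
import sys


def run_in_fresh_interpreter(code: str, seed: str) -> str:
    """Run `code` in a new Python process with the given PYTHONHASHSEED."""
    env = dict(os.environ, PYTHONHASHSEED=seed)
    completed = subprocess.run(
        [sys.executable, "-c", code],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return completed.stdout.strip()


BUILTIN = "print(hash('stencil source'))"
CONTENT = "import hashlib; print(hashlib.sha256(b'stencil source').hexdigest())"

# Two interpreter runs with different hash seeds, mimicking two separate sessions.
builtin_hashes = [run_in_fresh_interpreter(BUILTIN, seed) for seed in ("1", "2")]
content_hashes = [run_in_fresh_interpreter(CONTENT, seed) for seed in ("1", "2")]

# The content digest is the same in every process, so it is usable as a disk-cache key.
stable = content_hashes[0] == content_hashes[1]
```

Comparing the two entries of `builtin_hashes` shows whether the built-in hash survived the seed change; the content digests do not depend on the seed at all.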
@@ -7,9 +7,12 @@
 # SPDX-License-Identifier: BSD-3-Clause

 from functools import reduce
-
+from gt4py.next.otf import languages, stages, workflow
+from gt4py.next.otf.binding import interface
 import numpy as np
 import pytest
+import diskcache
+from gt4py.eve import SymbolName

 import gt4py.next as gtx
 from gt4py.next import (
@@ -30,7 +33,7 @@
 from gt4py.next.program_processors.runners import gtfn
 from gt4py.next.type_system import type_specifications as ts
 from gt4py.next import utils as gt_utils
-
+from gt4py.next import config
 from next_tests.integration_tests import cases
 from next_tests.integration_tests.cases import (
     C2E,