feat[next]: Add memory and disk-based caching to more workflow steps #1690

Merged: 67 commits, merged Nov 7, 2024

Changes from 63 of 67 commits

Commits
7e78704
Fix lowering_utils._expr_hash stability across runs
tehrengruber Feb 28, 2024
7515d07
Pass manager caching
tehrengruber Feb 28, 2024
54bbea0
Fix format
tehrengruber Feb 28, 2024
0cfc583
Merge branch 'fix_caching2' into pass_manager_caching
tehrengruber Feb 28, 2024
04d8459
Use content_hash
tehrengruber Feb 28, 2024
eac113c
Fix import
tehrengruber Feb 28, 2024
679521a
Small fix
tehrengruber Feb 28, 2024
2e3f17b
Small fix
tehrengruber Feb 28, 2024
2803f41
Merge branch 'fix_caching2' into pass_manager_caching
tehrengruber Feb 28, 2024
d6adb0e
Small fix
tehrengruber Feb 28, 2024
b39b22f
Small cleanup
tehrengruber Feb 28, 2024
06c0384
Merge branch 'fix_caching2' into pass_manager_caching
tehrengruber Feb 28, 2024
bbc3b4d
Fix cache dir creation
tehrengruber May 14, 2024
3e0b021
Merge origin/main
tehrengruber Jun 21, 2024
8f5aedd
Add todo
tehrengruber Jun 21, 2024
e039057
Fix typo
tehrengruber Jun 21, 2024
5a8b525
Cleanup
tehrengruber Jun 21, 2024
cf213c4
Add tests
tehrengruber Jun 25, 2024
b724a8d
Merge pull request #1 from GridTools/main
SF-N Jul 8, 2024
6bd5652
Merge remote-tracking branch 'origin/main'
SF-N Oct 10, 2024
ff055f1
Merge remote-tracking branch 'origin/main'
SF-N Oct 15, 2024
cdfd6b5
Extend lap_test to also run benchmark
SF-N Oct 15, 2024
6806f01
Add pytest-benchmark to requirements-dev
SF-N Oct 16, 2024
fb93b73
Disable pytest-xdist to enable benchmarking
SF-N Oct 16, 2024
ad76088
remove -n from tox.ini
SF-N Oct 16, 2024
256d9de
Enable pytest-xdist and run benchmarks separately
SF-N Oct 16, 2024
f00c60a
Change order for faster CI results
SF-N Oct 16, 2024
ccc35d0
Add pytest flags and reformat
SF-N Oct 16, 2024
780a955
Use larger grid for lap_benchmark
SF-N Oct 16, 2024
3effb43
Add pytest.ini for benchmark comparison
SF-N Oct 16, 2024
01bfda8
Add pygal to requirements for histogram
SF-N Oct 16, 2024
9995bec
Remove benchmark comparison
SF-N Oct 16, 2024
afd3e96
Set cached = True in func_to_past_factory and reformat
SF-N Oct 16, 2024
f72c84f
Set cached = False in func_to_past_factory because of missing perform…
SF-N Oct 16, 2024
736a924
WIP remote debugging session
egparedes Oct 16, 2024
833c1e9
Remove debugging output
SF-N Oct 16, 2024
1d3310c
Add further performance improvements with try-except in gtfn and backend
SF-N Oct 16, 2024
90414cc
Revert changes in gtfn.py
SF-N Oct 24, 2024
570385c
Merge origin/main
SF-N Oct 24, 2024
94fcb90
Merge branch 'main' into optimize_program
SF-N Oct 24, 2024
d856bb3
Merge branch 'main' into optimize_program
SF-N Oct 28, 2024
77f4d2f
Merge remote-tracking branch 'origin_tehrengruber/pass_manager_cachin…
SF-N Oct 28, 2024
174c7fc
File caching with CachedStep (still commented in gtfn because of erro…
SF-N Oct 29, 2024
26c7060
Merge branch 'main' into optimize_program
SF-N Oct 30, 2024
7aba4bd
Use diskcache instead of shelve, cleanup and add new test
SF-N Oct 31, 2024
28eb4da
Update caching and tests
SF-N Oct 31, 2024
27a141b
Remove test
SF-N Oct 31, 2024
dc04786
Cleanup
SF-N Nov 1, 2024
ff8a162
Fix skipping in test_execution
SF-N Nov 1, 2024
f31957d
Try to fix tox.ini
SF-N Nov 1, 2024
2bf3d11
Merge branch 'main' into optimize_program
SF-N Nov 1, 2024
f3dd7ab
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
e0db9fd
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
7a7c0e5
Further cleanup
SF-N Nov 4, 2024
271dd49
Run pre-commit
SF-N Nov 4, 2024
c4e5e32
Further cleanup
SF-N Nov 4, 2024
406d5e1
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
e908009
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
98ab673
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
8f15ee2
Merge branch 'main' into optimize_program
SF-N Nov 4, 2024
b7241bf
Add docstrings and move tests
SF-N Nov 4, 2024
c7bcc6a
Double-check add_content_to_fingerprint
SF-N Nov 5, 2024
3f76e88
Merge branch 'main' into optimize_program
SF-N Nov 5, 2024
87d4e94
Address review comments
SF-N Nov 5, 2024
23a1432
Update requirements
egparedes Nov 5, 2024
fd6bde3
Merge branch 'main' into optimize_program
SF-N Nov 7, 2024
9404b4a
Address review comments
SF-N Nov 7, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
'cytoolz>=0.12.1',
'deepdiff>=5.6.0',
'devtools>=0.6',
'diskcache>=5.6.3',
'factory-boy>=3.3.0',
'frozendict>=2.3',
'gridtools-cpp>=2.3.6,==2.*',
3 changes: 3 additions & 0 deletions src/gt4py/next/config.py
@@ -73,6 +73,9 @@ def env_flag_to_bool(name: str, default: bool) -> bool:
)


GTFN_SOURCE_CACHE_DIR: str = os.environ.get(f"{_PREFIX}_GTFN_SOURCE_CACHE_DIR", "gtfn_cache")

Contributor:
I would prefer to avoid exposing a config option that is special to a certain backend. Is it necessary?

Contributor Author:
We introduced this since we wanted to avoid hardcoding it. In almost all cases it should not change, but in the tests, for example, where the cache should be deleted afterwards, specifying a different directory makes sense.
Do you have a suggestion for how you would prefer to make this configurable?

Contributor:
DaCe also has a cache directory; maybe we could use a common config option for compiled backends?

Contributor Author:
Yes, that sounds good to me. Do you agree @havogt?
And do you have a name in mind?

Contributor:
I originally proposed this, but after thinking about it again, and together with my other suggestion about how to deal with this in tests, I think it's not needed at least for this PR.

Contributor Author:
Ok, I removed it and hardcoded it in gtfn.py.
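
(For illustration, the hardcoded variant could look roughly like the following module-level constant in gtfn.py; the follow-up commit is not part of this 63-commit view, so the exact name is an assumption.)

# hypothetical sketch, not the PR's final code:
_GTFN_SOURCE_CACHE_DIR = "gtfn_cache"  # module-level constant replacing the config entry
# the FileCache shown further below would then be built as
# FileCache(str(config.BUILD_CACHE_DIR / _GTFN_SOURCE_CACHE_DIR))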



#: Whether generated code projects should be kept around between runs.
#: - SESSION: generated code projects get destroyed when the interpreter shuts down
#: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/func_to_past.py
@@ -64,7 +64,7 @@ def func_to_past(inp: DSL_PRG) -> PRG:
)


def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]:
def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PRG]:
"""
Wrap `func_to_past` in a chainable and optionally cached workflow step.

8 changes: 6 additions & 2 deletions src/gt4py/next/ffront/stages.py
@@ -100,6 +100,7 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No

@add_content_to_fingerprint.register(FieldOperatorDefinition)
@add_content_to_fingerprint.register(FoastOperatorDefinition)
@add_content_to_fingerprint.register(ProgramDefinition)
@add_content_to_fingerprint.register(PastProgramDefinition)
@add_content_to_fingerprint.register(toolchain.CompilableProgram)
@add_content_to_fingerprint.register(arguments.CompileTimeArgs)
@@ -121,10 +122,14 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo
for item in sourcedef:
add_content_to_fingerprint(item, hasher)

closure_vars = source_utils.get_closure_vars_from_function(obj)
for item in sorted(closure_vars.items(), key=lambda x: x[0]):
add_content_to_fingerprint(item, hasher)


@add_content_to_fingerprint.register
def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None:
for key, value in obj.items():
for key, value in sorted(obj.items()):
add_content_to_fingerprint(key, hasher)
add_content_to_fingerprint(value, hasher)

@@ -148,4 +153,3 @@ def add_foast_located_node_to_fingerprint(
) -> None:
add_content_to_fingerprint(obj.location, hasher)
add_content_to_fingerprint(str(obj), hasher)
add_content_to_fingerprint(str(obj), hasher)

Contributor Author @SF-N (Nov 5, 2024):
I am not sure why this was done twice. Does removing it make sense, @egparedes / @DropD?

Contributor:
It looks like a bug to me and I agree with your change.
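
For context, updating the hasher twice with the same string does not make the fingerprint more discriminating; it only perturbs the digest, so dropping one of the two calls is safe. A standalone illustration using hashlib (not the PR's code):

import hashlib

def fingerprint(parts: list[str]) -> str:
    h = hashlib.sha256()
    for part in parts:
        h.update(part.encode())  # mirrors how add_content_to_fingerprint feeds the hasher
    return h.hexdigest()

node_repr = "FoastLocatedNode(...)"
once = fingerprint(["<location>", node_repr])
twice = fingerprint(["<location>", node_repr, node_repr])
assert once != twice  # digests differ, but both distinguish exactly the same inputs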

4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/ir.py
@@ -208,7 +208,9 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait):
closures: List[StencilClosure]
implicit_domain: bool = False

_NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS]
_NODE_SYMBOLS_: ClassVar[List[Sym]] = [
Sym(id=name) for name in sorted(BUILTINS)
] # sorted for serialization stability


class Stmt(Node): ...
8 changes: 4 additions & 4 deletions src/gt4py/next/otf/workflow.py
@@ -12,6 +12,7 @@
import dataclasses
import functools
import typing
from collections.abc import MutableMapping
from typing import Any, Callable, Generic, Protocol, TypeVar

from typing_extensions import Self
@@ -253,16 +254,15 @@ class CachedStep(

step: Workflow[StartT, EndT]
hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment]

_cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict)
cache: MutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict)

def __call__(self, inp: StartT) -> EndT:
"""Run the step only if the input is not cached, else return from cache."""
hash_ = self.hash_function(inp)
try:
result = self._cache[hash_]
result = self.cache[hash_]
except KeyError:
result = self._cache[hash_] = self.step(inp)
result = self.cache[hash_] = self.step(inp)
return result


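The change above widens the cache from a private dict to any MutableMapping, so the same memoization logic can be backed by an in-memory dict or an on-disk mapping. A self-contained sketch of the pattern, simplified and without the gt4py generics (TinyCachedStep is a made-up name):

import dataclasses
from collections.abc import MutableMapping
from typing import Any, Callable

@dataclasses.dataclass(frozen=True)
class TinyCachedStep:
    step: Callable[[Any], Any]
    hash_function: Callable[[Any], Any] = hash
    cache: MutableMapping[Any, Any] = dataclasses.field(default_factory=dict)

    def __call__(self, inp: Any) -> Any:
        key = self.hash_function(inp)
        try:
            return self.cache[key]
        except KeyError:
            result = self.cache[key] = self.step(inp)
            return result

calls = []
cached = TinyCachedStep(step=lambda x: calls.append(x) or x * x)
assert cached(3) == 9 and cached(3) == 9
assert calls == [3]  # the wrapped step ran only once; the second call hit the cache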
@@ -213,6 +213,7 @@ def generate_stencil_source(
generated_code = GTFNIMCodegen.apply(gtfn_im_ir)
else:
generated_code = GTFNCodegen.apply(gtfn_ir)

return codegen.format_source("cpp", generated_code, style="LLVM")

def __call__(
61 changes: 54 additions & 7 deletions src/gt4py/next/program_processors/runners/gtfn.py
@@ -8,16 +8,19 @@

import functools
import warnings
from typing import Any
from typing import Any, Optional

import diskcache
import factory
import numpy.typing as npt

import gt4py._core.definitions as core_defs
import gt4py.next.allocators as next_allocators
from gt4py.eve import utils
from gt4py.eve.utils import content_hash
from gt4py.next import backend, common, config
from gt4py.next.iterator import transforms
from gt4py.next.common import Connectivity, Dimension
from gt4py.next.iterator import ir as itir, transforms
from gt4py.next.otf import arguments, recipes, stages, workflow
from gt4py.next.otf.binding import nanobind
from gt4py.next.otf.compilation import compiler
@@ -116,6 +119,37 @@ def compilation_hash(otf_closure: stages.CompilableProgram) -> int:
)


def generate_stencil_source_hash_function(inp: stages.CompilableProgram) -> str:
"""
Generates a unique hash string for a stencil source program representing
the program, sorted offset_provider, and column_axis.
"""
program: itir.FencilDefinition | itir.Program = inp.data
offset_provider: dict[str, Connectivity | Dimension] = inp.args.offset_provider
column_axis: Optional[common.Dimension] = inp.args.column_axis

program_hash = utils.content_hash(
(
program,
sorted(offset_provider.items(), key=lambda el: el[0]),
column_axis,
)
)

return program_hash


class FileCache(diskcache.Cache):
"""
This class extends `diskcache.Cache` to ensure the cache is closed upon deletion,
i.e. it ensures that any resources associated with the cache are properly
released when the instance is garbage collected.
"""

def __del__(self) -> None:
self.close()


class GTFNCompileWorkflowFactory(factory.Factory):
class Meta:
model = recipes.OTFCompileWorkflow
@@ -129,10 +163,23 @@ class Params:
lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type)
)

translation = factory.SubFactory(
gtfn_module.GTFNTranslationStepFactory,
device_type=factory.SelfAttribute("..device_type"),
)
cached_translation = factory.Trait(
translation=factory.LazyAttribute(
lambda o: workflow.CachedStep(
o.translation_,
hash_function=generate_stencil_source_hash_function,
cache=FileCache(str(config.BUILD_CACHE_DIR / config.GTFN_SOURCE_CACHE_DIR)),
)
),
)

translation_ = factory.SubFactory(
gtfn_module.GTFNTranslationStepFactory,
device_type=factory.SelfAttribute("..device_type"),
)

translation = factory.LazyAttribute(lambda o: o.translation_)

bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = (
nanobind.bind_source
)
@@ -193,7 +240,7 @@ class Params:
name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True
)

run_gtfn_cached = GTFNBackendFactory(cached=True)
run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True)

run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True)

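The FileCache introduced above is a thin wrapper over diskcache.Cache, whose entries persist on disk between processes; the new tests below exercise exactly that behaviour. A minimal standalone sketch (the directory and key are made up):

import tempfile
import diskcache

cache_dir = tempfile.mkdtemp()

with diskcache.Cache(cache_dir) as cache:
    cache["some-hash"] = "generated stencil source"  # written to disk

# a fresh Cache pointed at the same directory still sees the entry
with diskcache.Cache(cache_dir) as reopened:
    assert reopened["some-hash"] == "generated stencil source"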
@@ -7,9 +7,12 @@
# SPDX-License-Identifier: BSD-3-Clause

from functools import reduce

from gt4py.next.otf import languages, stages, workflow
from gt4py.next.otf.binding import interface
import numpy as np
import pytest
import diskcache
from gt4py.eve import SymbolName

import gt4py.next as gtx
from gt4py.next import (
@@ -30,7 +33,7 @@
from gt4py.next.program_processors.runners import gtfn
from gt4py.next.type_system import type_specifications as ts
from gt4py.next import utils as gt_utils

from gt4py.next import config
from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
C2E,
@@ -8,13 +8,29 @@

import numpy as np
import pytest
import tempfile
import pathlib
import os
import pickle
import copy
import diskcache


import gt4py.next as gtx
import gt4py.next.config
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.otf import arguments, languages, stages
from gt4py.next.otf import arguments, languages, stages, workflow, toolchain
from gt4py.next.program_processors.codegens.gtfn import gtfn_module
from gt4py.next.program_processors.runners import gtfn
from gt4py.next.type_system import type_translation
from next_tests.integration_tests import cases

from next_tests.integration_tests.cases import cartesian_case

from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
exec_alloc_descriptor,
)


@pytest.fixture
@@ -71,3 +87,89 @@ def test_codegen(fencil_example):
assert module.entry_point.name == fencil.id
assert any(d.name == "gridtools_cpu" for d in module.library_deps)
assert module.language is languages.CPP


def test_hash_and_diskcache(fencil_example):
fencil, parameters = fencil_example
compilable_program = stages.CompilableProgram(
data=fencil,
args=arguments.CompileTimeArgs.from_concrete_no_size(
*parameters, **{"offset_provider": {}}
),
)

hash = gtfn.generate_stencil_source_hash_function(compilable_program)
path = str(gt4py.next.config.BUILD_CACHE_DIR / gt4py.next.config.GTFN_SOURCE_CACHE_DIR)
with diskcache.Cache(path) as cache:
cache[hash] = compilable_program

# check content of cache file
with diskcache.Cache(path) as reopened_cache:
assert hash in reopened_cache
compilable_program_from_cache = reopened_cache[hash]
assert compilable_program == compilable_program_from_cache
del reopened_cache[hash] # delete data

# hash creation is deterministic
assert hash == gtfn.generate_stencil_source_hash_function(compilable_program)
assert hash == gtfn.generate_stencil_source_hash_function(compilable_program_from_cache)

# hash is different if program changes
altered_program = copy.deepcopy(compilable_program)
altered_program.data.id = "example2"
assert gtfn.generate_stencil_source_hash_function(
compilable_program
) != gtfn.generate_stencil_source_hash_function(altered_program)


def test_gtfn_file_cache(fencil_example):
fencil, parameters = fencil_example
compilable_program = stages.CompilableProgram(
data=fencil,
args=arguments.CompileTimeArgs.from_concrete_no_size(
*parameters, **{"offset_provider": {}}
),
)
cached_gtfn_translation_step = gtfn.GTFNBackendFactory(
gpu=False, cached=True, otf_workflow__cached_translation=True
).executor.step.translation

bare_gtfn_translation_step = gtfn.GTFNBackendFactory(
gpu=False, cached=True, otf_workflow__cached_translation=False
).executor.step.translation

cached_gtfn_translation_step(
compilable_program
) # run cached translation step once to populate cache
assert bare_gtfn_translation_step(compilable_program) == cached_gtfn_translation_step(
compilable_program
)

cache_key = gtfn.generate_stencil_source_hash_function(compilable_program)
assert cache_key in cached_gtfn_translation_step.cache
assert (
bare_gtfn_translation_step(compilable_program)
== cached_gtfn_translation_step.cache[cache_key]
)


def test_gtfn_file_cache_whole_workflow(cartesian_case):
if cartesian_case.backend != gtfn.run_gtfn:
pytest.skip("Skipping backend.")
cartesian_case.backend = gtfn.GTFNBackendFactory(
gpu=False, cached=True, otf_workflow__cached_translation=True
)

@gtx.field_operator
def testee(a: cases.IJKField) -> cases.IJKField:
field_tuple = (a, a)
field_0 = field_tuple[0]
field_1 = field_tuple[1]
return field_0

# first call: this generates the cache file
cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a)
# clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step is executed again
object.__setattr__(cartesian_case.backend.executor, "cache", {})
# second call: the cache file is used
cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a)