Skip to content

Commit

Permalink
refactor[next]: remove use of Fencil in tracing (eliminate closure) (
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt authored Dec 6, 2024
1 parent 39fb949 commit 8b6abc2
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 57 deletions.
11 changes: 6 additions & 5 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,8 +1706,10 @@ def impl(*iters: ItIterator):
return impl


def _dimension_to_tag(domain: Domain) -> dict[Tag, range]:
return {k.value if isinstance(k, common.Dimension) else k: v for k, v in domain.items()}
def _dimension_to_tag(
domain: runtime.CartesianDomain | runtime.UnstructuredDomain,
) -> dict[Tag, range]:
return {k.value: v for k, v in domain.items()}


def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProviderType) -> None:
Expand Down Expand Up @@ -1828,7 +1830,7 @@ def impl(*args):

# TODO(havogt): after updating all tests to use the new program,
# we should get rid of closure and move the implementation to this function
closure(_dimension_to_tag(domain), fun, out, list(args))
closure(domain, fun, out, list(args))
return out

return impl
Expand All @@ -1839,9 +1841,8 @@ def index(axis: common.Dimension) -> common.Field:
return IndexField(axis)


@runtime.closure.register(EMBEDDED)
def closure(
domain_: Domain,
domain_: runtime.CartesianDomain | runtime.UnstructuredDomain,
sten: Callable[..., Any],
out, #: MutableLocatedField,
ins: list[common.Field | Scalar | tuple[common.Field | Scalar | tuple, ...]],
Expand Down
9 changes: 2 additions & 7 deletions src/gt4py/next/iterator/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# TODO(tehrengruber): remove cirular dependency and import unconditionally
from gt4py.next import backend as next_backend

__all__ = ["offset", "fundef", "fendef", "closure", "set_at", "if_stmt"]
__all__ = ["offset", "fundef", "fendef", "set_at", "if_stmt"]


@dataclass(frozen=True)
Expand Down Expand Up @@ -163,7 +163,7 @@ def impl(out, *inps):
# if passed as a dict, we need to convert back to builtins for interpretation by the backends
assert offset_provider is not None
dom = _deduce_domain(dom, common.offset_provider_to_type(offset_provider))
closure(dom, self.fundef_dispatcher, out, [*inps])
set_at(builtins.as_fieldop(self.fundef_dispatcher, dom)(*inps), dom, out)

return impl

Expand Down Expand Up @@ -208,11 +208,6 @@ def fundef(fun):
return FundefDispatcher(fun)


@builtin_dispatch
def closure(*args): # TODO remove
return BackendNotSelectedError()


@builtin_dispatch
def set_at(*args):
return BackendNotSelectedError()
Expand Down
44 changes: 7 additions & 37 deletions src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Lambda,
NoneLiteral,
OffsetLiteral,
StencilClosure,
Sym,
SymRef,
)
Expand Down Expand Up @@ -202,20 +201,13 @@ def __bool__(self):

class TracerContext:
fundefs: ClassVar[List[FunctionDefinition]] = []
closures: ClassVar[
List[StencilClosure]
] = [] # TODO(havogt): remove after refactoring to `Program` is complete, currently handles both programs and fencils
body: ClassVar[List[itir.Stmt]] = []

@classmethod
def add_fundef(cls, fun):
if fun not in cls.fundefs:
cls.fundefs.append(fun)

@classmethod
def add_closure(cls, closure):
cls.closures.append(closure)

@classmethod
def add_stmt(cls, stmt):
cls.body.append(stmt)
Expand All @@ -225,23 +217,10 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, exc_traceback):
type(self).fundefs = []
type(self).closures = []
type(self).body = []
iterator.builtins.builtin_dispatch.pop_key()


@iterator.runtime.closure.register(TRACING)
def closure(domain, stencil, output, inputs):
if hasattr(stencil, "__name__") and stencil.__name__ in iterator.builtins.__all__:
stencil = _s(stencil.__name__)
else:
stencil(*(_s(param) for param in inspect.signature(stencil).parameters))
stencil = make_node(stencil)
TracerContext.add_closure(
StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs)
)


@iterator.runtime.set_at.register(TRACING)
def set_at(expr: itir.Expr, domain: itir.Expr, target: itir.Expr) -> None:
TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target))
Expand Down Expand Up @@ -328,19 +307,10 @@ def trace_fencil_definition(
params = _make_fencil_params(fun, args)
trace_function_call(fun, args=(_s(param.id) for param in params))

if TracerContext.closures:
return itir.FencilDefinition(
id=fun.__name__,
function_definitions=TracerContext.fundefs,
params=params,
closures=TracerContext.closures,
)
else:
assert TracerContext.body
return itir.Program(
id=fun.__name__,
function_definitions=TracerContext.fundefs,
params=params,
declarations=[], # TODO
body=TracerContext.body,
)
return itir.Program(
id=fun.__name__,
function_definitions=TracerContext.fundefs,
params=params,
declarations=[], # TODO
body=TracerContext.body,
)
1 change: 0 additions & 1 deletion src/gt4py/next/program_processors/runners/roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class EmbeddedDSL(codegen.TemplatedGenerator):
AxisLiteral = as_fmt("{value}")
FunCall = as_fmt("{fun}({','.join(args)})")
Lambda = as_mako("(lambda ${','.join(params)}: ${expr})")
StencilClosure = as_mako("closure(${domain}, ${stencil}, ${output}, [${','.join(inputs)}])")
FunctionDefinition = as_mako(
"""
@fundef
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from gt4py.next.iterator import builtins as it_builtins
from gt4py.next.iterator.builtins import (
and_,
as_fieldop,
bool,
can_deref,
cartesian_domain,
Expand Down Expand Up @@ -45,9 +46,8 @@
plus,
shift,
xor_,
as_fieldop,
)
from gt4py.next.iterator.runtime import set_at, closure, fendef, fundef, offset
from gt4py.next.iterator.runtime import fendef, fundef, offset, set_at
from gt4py.next.program_processors.runners.gtfn import run_gtfn

from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data
Expand Down
1 change: 1 addition & 0 deletions tests/next_tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]:
(None, True),
(next_tests.definitions.ProgramBackendId.ROUNDTRIP, True),
(next_tests.definitions.ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES, True),
(next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, True),
(next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True),
(next_tests.definitions.ProgramBackendId.GTFN_CPU, True),
(next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,20 @@ def foo(inp):
dtype=None,
)

I = gtx.Dimension("I")


def test_deduce_domain():
assert isinstance(_deduce_domain({}, {}), CartesianDomain)
assert isinstance(_deduce_domain(UnstructuredDomain(), {}), UnstructuredDomain)
assert isinstance(_deduce_domain({}, {"foo": connectivity}), UnstructuredDomain)
assert isinstance(
_deduce_domain(CartesianDomain([("I", range(1))]), {"foo": connectivity}), CartesianDomain
_deduce_domain(CartesianDomain([(I, range(1))]), {"foo": connectivity}), CartesianDomain
)


I = gtx.Dimension("I")


def test_embedded_error_on_wrong_domain():
dom = CartesianDomain([("I", range(1))])
dom = CartesianDomain([(I, range(1))])

out = gtx.as_field([I], np.zeros(1))
with pytest.raises(RuntimeError, match="expected 'UnstructuredDomain'"):
Expand Down

0 comments on commit 8b6abc2

Please sign in to comment.