Skip to content

Commit

Permalink
Use registry mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Aug 29, 2023
1 parent 9c1e9e4 commit 35fecfc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
37 changes: 17 additions & 20 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
from __future__ import annotations

import abc
import collections
import dataclasses
import enum
import functools
import sys
from collections.abc import Sequence, Set
from types import EllipsisType
from typing import Type, TypeGuard, overload
from typing import ChainMap, TypeGuard, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -233,10 +234,6 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _
class Field(Protocol[DimsT, core_defs.ScalarT]):
__gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher]

if TYPE_CHECKING:
register_builtin_func: Callable
_builtin_func_map: dict[fbuiltins.BuiltInFunction, Callable]

@property
def domain(self) -> Domain:
...
Expand Down Expand Up @@ -487,21 +484,21 @@ def is_domain_slice(index: Any) -> TypeGuard[DomainSlice]:
)


def enable_builtin_func_registry(cls: Type[Field]) -> Type[Field]:
cls._builtin_func_map = {}
setattr(cls, "register_builtin_func", classmethod(register_builtin_func))
setattr(cls, "__gt_builtin_func__", classmethod(__gt_builtin_func__))
return cls

class FieldBuiltinFuncRegistry:
_builtin_func_map: ChainMap[fbuiltins.BuiltInFunction, Callable] = collections.ChainMap()

def register_builtin_func(
cls, /, op: fbuiltins.BuiltInFunction[_R, _P], op_func: Optional[Callable[_P, _R]] = None
) -> Any:
assert op not in cls._builtin_func_map
if op_func is None: # when used as a decorator
return functools.partial(cls.register_builtin_func, op)
return cls._builtin_func_map.setdefault(op, op_func)
def __init_subclass__(cls, **kwargs):
cls._builtin_func_map = cls._builtin_func_map.new_child()

@classmethod
def register_builtin_func(
cls, /, op: fbuiltins.BuiltInFunction[_R, _P], op_func: Optional[Callable[_P, _R]] = None
) -> Any:
assert op not in cls._builtin_func_map
if op_func is None: # when used as a decorator
return functools.partial(cls.register_builtin_func, op)
return cls._builtin_func_map.setdefault(op, op_func)

def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Any:
return cls._builtin_func_map.get(func, NotImplemented)
@classmethod
def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Any:
return cls._builtin_func_map.get(func, NotImplemented)
4 changes: 2 additions & 2 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from gt4py._core import definitions as core_defs
from gt4py.next import common
from gt4py.next.common import FieldBuiltinFuncRegistry
from gt4py.next.ffront import fbuiltins


Expand Down Expand Up @@ -80,9 +81,8 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field:
_R = TypeVar("_R", _Value, tuple[_Value, ...])


@common.enable_builtin_func_registry
@dataclasses.dataclass(frozen=True)
class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT]):
class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldBuiltinFuncRegistry):
"""
Shared field implementation for NumPy-like fields.
Expand Down

0 comments on commit 35fecfc

Please sign in to comment.