Skip to content

Commit

Permalink
Replace wrong fixes by correct ones
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Oct 25, 2023
1 parent cdc9853 commit a5e1095
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 74 deletions.
67 changes: 0 additions & 67 deletions src/gt4py/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,73 +228,6 @@ def itemgetter_(key: Any, default: Any = NOTHING) -> Callable[[Any], Any]:
return lambda obj: getitem_(obj, key, default=default)


_C = TypeVar("_C")
_V = TypeVar("_V")

_dataclass: Final[Callable] = (
functools.partial(dataclasses.dataclass, slots=True)
if sys.version_info >= (3, 10)
else dataclasses.dataclass
)


@_dataclass(frozen=True, slots=True)
class ForwardDescriptor(xtyping.NonDataDescriptor[_C, _V]):
"""
Descriptor to forward attribute access to another member of the object.
Args:
source_member: name of the member to forward the attribute access to.
attribute_name: name of the attribute to be forwarded. If `None`,
the name of the descriptor in the owner class is used.
Examples:
>>> class A:
... def __init__(self, value):
... self.value = value
...
>>> class B:
... def __init__(self, a):
... self.a = a
...
... value = ForwardDescriptor('a')
...
>>> a = A(10)
>>> b = B(a)
>>> b.value
10
"""

source_member: str
attribute_name: Optional[str] = None

def __set_name__(self, _owner_type: _C, _name: str) -> None:
if self.attribute_name is None:
object.__setattr__(self, "attribute_name", _name)

@overload
def __get__(
self, _instance: Literal[None], _owner_type: Optional[Type[_C]] = None
) -> ForwardDescriptor[_C, _V]:
...

@overload
def __get__( # noqa: F811 # redefinion of unused member
self, _instance: _C, _owner_type: Optional[Type[_C]] = None
) -> _V:
...

def __get__( # noqa: F811 # redefinion of unused member
self, _instance: Optional[_C], _owner_type: Optional[Type[_C]] = None
) -> _V | ForwardDescriptor[_C, _V]:
assert self.attribute_name is not None
return (
getattr(getattr(_instance, self.source_member), self.attribute_name)
if _instance is not None
else self
)


_P = ParamSpec("_P")


Expand Down
2 changes: 0 additions & 2 deletions src/gt4py/next/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ def __gt_allocate__(
shape, dtype, device_id, layout_map, self.byte_alignment, aligned_index
)

__call__ = __gt_allocate__


if TYPE_CHECKING:
__TensorFieldAllocatorAsFieldAllocatorInterfaceT: type[FieldAllocatorInterface] = FieldAllocator
Expand Down
16 changes: 11 additions & 5 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
from gt4py.next import allocators as next_allocators
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.ffront import (
dialect_ast_enums,
Expand Down Expand Up @@ -173,6 +174,9 @@ class Program:
backend: Optional[ppi.ProgramExecutor] = None
grid_type: Optional[GridType] = None

__gt_device_type__: next_allocators.FieldAllocatorInterface.__gt_device_type__
__gt_allocate__: next_allocators.FieldAllocatorInterface.__gt_allocate__

@classmethod
def from_function(
cls,
Expand Down Expand Up @@ -213,11 +217,13 @@ def __post_init__(self):
raise RuntimeError(
f"The following closure variables are undefined: {', '.join(undefined_symbols)}"
)
if self.backend is not None and hasattr(self.backend, "__gt_allocate__"):
object.__setattr__(self, "__gt_allocate__", self.backend.__gt_allocate__)

__gt_device_type__ = eve_utils.ForwardDescriptor("backend")
__gt_allocate__ = eve_utils.ForwardDescriptor("backend")
if self.backend is not None:
object.__setattr__(
self, "__gt_device_type__", getattr(self.backend, "__gt_device_type__", None)
)
object.__setattr__(
self, "__gt_allocate__", getattr(self.backend, "__gt_allocate__", None)
)

def with_backend(self, backend: ppi.ProgramExecutor) -> Program:
return dataclasses.replace(self, backend=backend)
Expand Down

0 comments on commit a5e1095

Please sign in to comment.