Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RFC 31: Enumeration type safety. #957

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 221 additions & 18 deletions amaranth/lib/enum.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import enum as py_enum
import warnings
import operator

from ..hdl.ast import Value, Shape, ShapeCastable, Const
from ..hdl.ast import Value, ValueCastable, Shape, ShapeCastable, Const
from ..hdl._repr import *


__all__ = py_enum.__all__
__all__ = py_enum.__all__ + ["EnumView", "FlagView"]


for _member in py_enum.__all__:
Expand All @@ -23,14 +24,18 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
:class:`enum.EnumMeta` class; if the ``shape=`` argument is not specified and
:meth:`as_shape` is never called, it places no restrictions on the enumeration class
or the values of its members.

When a :ref:`value-castable <lang-valuecasting>` is cast to an enum type that is an instance
of this metaclass, it can be automatically wrapped in a view class. A custom view class
can be specified by passing the ``view_class=`` keyword argument when creating the enum class.
"""

# TODO: remove this shim once py3.8 support is dropped
@classmethod
def __prepare__(metacls, name, bases, shape=None, **kwargs):
def __prepare__(metacls, name, bases, shape=None, view_class=None, **kwargs):
return super().__prepare__(name, bases, **kwargs)

def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
def __new__(metacls, name, bases, namespace, shape=None, view_class=None, **kwargs):
if shape is not None:
shape = Shape.cast(shape)
# Prepare enumeration members for instantiation. This logic is unfortunately very
Expand Down Expand Up @@ -89,6 +94,8 @@ def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
# Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that
# the values of every member can be cast to the provided shape without truncation.
cls._amaranth_shape_ = shape
if view_class is not None:
cls._amaranth_view_class_ = view_class
else:
# Shape is not provided explicitly. Behave the same as a standard enumeration;
# the lack of `_amaranth_shape_` attribute is used to emit a warning when such
Expand Down Expand Up @@ -127,17 +134,32 @@ def as_shape(cls):
return Shape._cast_plain_enum(cls)

def __call__(cls, value, *args, **kwargs):
# :class:`py_enum.Enum` uses ``__call__()`` for type casting: ``E(x)`` returns
# the enumeration member whose value equals ``x``. In this case, ``x`` must be a concrete
# value.
# Amaranth extends this to indefinite values, but conceptually the operation is the same:
# :class:`View` calls :meth:`Enum.__call__` to go from a :class:`Value` to something
# representing this enumeration with that value.
# At the moment however, for historical reasons, this is just the value itself. This works
# and is backwards-compatible but is limiting in that it does not allow us to e.g. catch
# comparisons with enum members of the wrong type.
if isinstance(value, Value):
return value
"""Cast the value to this enum type.

When given an integer constant, it returns the corresponding enum value, like a standard
Python enumeration.

When given a :ref:`value-castable <lang-valuecasting>`, it is cast to a value, then wrapped
in the ``view_class`` specified for this enum type (:class:`EnumView` for :class:`Enum`,
:class:`FlagView` for :class:`Flag`, or a custom user-defined class). If the type has no
``view_class`` (like :class:`IntEnum` or :class:`IntFlag`), a plain
:class:`Value` is returned.

Returns
-------
instance of itself
For integer values, or instances of itself.
:class:`EnumView` or its subclass
For value-castables, as defined by the ``view_class`` keyword argument.
:class:`Value`
For value-castables, when a view class is not specified for this enum.
"""
if isinstance(value, (Value, ValueCastable)):
value = Value.cast(value)
if cls._amaranth_view_class_ is None:
return value
else:
return cls._amaranth_view_class_(cls, value)
return super().__call__(value, *args, **kwargs)

def const(cls, init):
Expand All @@ -149,15 +171,15 @@ def const(cls, init):
member = cls(0)
else:
member = cls(init)
return Const(member.value, cls.as_shape())
return cls(Const(member.value, cls.as_shape()))

def _value_repr(cls, value):
yield Repr(FormatEnum(cls), value)


class Enum(py_enum.Enum):
"""Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as
its metaclass."""
its metaclass and :class:`EnumView` as its view class."""


class IntEnum(py_enum.IntEnum):
Expand All @@ -167,16 +189,197 @@ class IntEnum(py_enum.IntEnum):

class Flag(py_enum.Flag):
"""Subclass of the standard :class:`enum.Flag` that has :class:`EnumMeta` as
its metaclass."""
its metaclass and :class:`FlagView` as its view class."""


class IntFlag(py_enum.IntFlag):
"""Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as
its metaclass."""


# Fix up the metaclass after the fact: the metaclass __new__ requires these classes
# to already be present, and also would not install itself on them due to lack of shape.
Enum.__class__ = EnumMeta
IntEnum.__class__ = EnumMeta
Flag.__class__ = EnumMeta
IntFlag.__class__ = EnumMeta


class EnumView(ValueCastable):
"""The view class used for :class:`Enum`.

Wraps a :class:`Value` and only allows type-safe operations. The only operators allowed are
equality comparisons (``==`` and ``!=``) with another :class:`EnumView` of the same enum type.
"""

def __init__(self, enum, target):
"""Constructs a view with the given enum type and target
(a :ref:`value-castable <lang-valuecasting>`).
"""
if not isinstance(enum, EnumMeta) or not hasattr(enum, "_amaranth_shape_"):
raise TypeError(f"EnumView type must be an enum with shape, not {enum!r}")
try:
cast_target = Value.cast(target)
except TypeError as e:
raise TypeError("EnumView target must be a value-castable object, not {!r}"
.format(target)) from e
if cast_target.shape() != enum.as_shape():
raise TypeError("EnumView target must have the same shape as the enum")
self.enum = enum
self.target = cast_target

def shape(self):
"""Returns the underlying enum type."""
return self.enum

@ValueCastable.lowermethod
def as_value(self):
"""Returns the underlying value."""
return self.target

def eq(self, other):
"""Assign to the underlying value.

Returns
-------
:class:`Assign`
``self.as_value().eq(other)``
"""
return self.as_value().eq(other)

def __add__(self, other):
raise TypeError("cannot perform arithmetic operations on non-IntEnum enum")

__radd__ = __add__
__sub__ = __add__
__rsub__ = __add__
__mul__ = __add__
__rmul__ = __add__
__floordiv__ = __add__
__rfloordiv__ = __add__
__mod__ = __add__
__rmod__ = __add__
__lshift__ = __add__
__rlshift__ = __add__
__rshift__ = __add__
__rrshift__ = __add__
__lt__ = __add__
__le__ = __add__
__gt__ = __add__
__ge__ = __add__

def __and__(self, other):
raise TypeError("cannot perform bitwise operations on non-IntEnum non-Flag enum")

__rand__ = __and__
__or__ = __and__
__ror__ = __and__
__xor__ = __and__
__rxor__ = __and__

def __eq__(self, other):
"""Compares the underlying value for equality.

The other operand has to be either another :class:`EnumView` with the same enum type, or
a plain value of the underlying enum.

Returns
-------
:class:`Value`
The result of the equality comparison, as a single-bit value.
"""
if isinstance(other, self.enum):
other = self.enum(Value.cast(other))
if not isinstance(other, EnumView) or other.enum is not self.enum:
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
return self.target == other.target

def __ne__(self, other):
if isinstance(other, self.enum):
other = self.enum(Value.cast(other))
if not isinstance(other, EnumView) or other.enum is not self.enum:
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
return self.target != other.target

def __repr__(self):
return f"{type(self).__name__}({self.enum.__name__}, {self.target!r})"


class FlagView(EnumView):
"""The view class used for :class:`Flag`.

In addition to the operations allowed by :class:`EnumView`, it allows bitwise operations among
values of the same enum type."""

def __invert__(self):
"""Inverts all flags in this value and returns another :ref:`FlagView`.

Note that this is not equivalent to applying bitwise negation to the underlying value:
just like the Python :class:`enum.Flag` class, only bits corresponding to flags actually
defined in the enumeration are included in the result.

Returns
-------
:class:`FlagView`
"""
if hasattr(self.enum, "_boundary_") and self.enum._boundary_ in (EJECT, KEEP):
return self.enum._amaranth_view_class_(self.enum, ~self.target)
else:
singles_mask = 0
for flag in self.enum:
if (flag.value & (flag.value - 1)) == 0:
singles_mask |= flag.value
return self.enum._amaranth_view_class_(self.enum, ~self.target & singles_mask)

def __bitop(self, other, op):
if isinstance(other, self.enum):
other = self.enum(Value.cast(other))
if not isinstance(other, FlagView) or other.enum is not self.enum:
raise TypeError("a FlagView can only perform bitwise operation with a value or other FlagView of the same enum type")
return self.enum._amaranth_view_class_(self.enum, op(self.target, other.target))

def __and__(self, other):
"""Performs a bitwise AND and returns another :class:`FlagView`.

The other operand has to be either another :class:`FlagView` of the same enum type, or
a plain value of the underlying enum type.

Returns
-------
:class:`FlagView`
"""
return self.__bitop(other, operator.__and__)

def __or__(self, other):
"""Performs a bitwise OR and returns another :class:`FlagView`.

The other operand has to be either another :class:`FlagView` of the same enum type, or
a plain value of the underlying enum type.

Returns
-------
:class:`FlagView`
"""
return self.__bitop(other, operator.__or__)

def __xor__(self, other):
"""Performs a bitwise XOR and returns another :class:`FlagView`.

The other operand has to be either another :class:`FlagView` of the same enum type, or
a plain value of the underlying enum type.

Returns
-------
:class:`FlagView`
"""
return self.__bitop(other, operator.__xor__)

__rand__ = __and__
__ror__ = __or__
__rxor__ = __xor__


Enum._amaranth_view_class_ = EnumView
IntEnum._amaranth_view_class_ = None
Flag._amaranth_view_class_ = FlagView
IntFlag._amaranth_view_class_ = None
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Implemented RFCs
.. _RFC 20: https://amaranth-lang.org/rfcs/0020-deprecate-non-fwft-fifos.html
.. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html
.. _RFC 28: https://amaranth-lang.org/rfcs/0028-override-value-operators.html
.. _RFC 31: https://amaranth-lang.org/rfcs/0031-enumeration-type-safety.html


* `RFC 1`_: Aggregate data structure library
Expand All @@ -77,6 +78,7 @@ Implemented RFCs
* `RFC 20`_: Deprecate non-FWFT FIFOs
* `RFC 22`_: Define ``ValueCastable.shape()``
* `RFC 28`_: Allow overriding ``Value`` operators
* `RFC 31`_: Enumeration type safety


Language changes
Expand Down
59 changes: 59 additions & 0 deletions docs/stdlib/enum.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ A shape can be specified for an enumeration with the ``shape=`` keyword argument

>>> Shape.cast(Funct)
unsigned(4)
>>> Value.cast(Funct.ADD)
(const 4'd0)

Any :ref:`constant-castable <lang-constcasting>` expression can be used as the value of a member:

Expand Down Expand Up @@ -57,6 +59,57 @@ The ``shape=`` argument is optional. If not specified, classes from this module

In this way, this module is a drop-in replacement for the standard :mod:`enum` module, and in an Amaranth project, all ``import enum`` statements may be replaced with ``from amaranth.lib import enum``.

Signals with :class:`Enum` or :class:`Flag` based shape are automatically wrapped in the :class:`EnumView` or :class:`FlagView` value-castable wrappers, which ensure type safety. Any :ref:`value-castable <lang-valuecasting>` can also be explicitly wrapped in a view class by casting it to the enum type:

.. doctest::

>>> a = Signal(Funct)
>>> b = Signal(Op)
>>> type(a)
<class 'amaranth.lib.enum.EnumView'>
>>> a == b
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: an EnumView can only be compared to value or other EnumView of the same enum type
>>> c = Signal(4)
>>> type(Funct(c))
<class 'amaranth.lib.enum.EnumView'>

Like the standard Python :class:`enum.IntEnum` and :class:`enum.IntFlag` classes, the Amaranth :class:`IntEnum` and :class:`IntFlag` classes are loosely typed and will not be subject to wrapping in view classes:

.. testcode::

class TransparentEnum(enum.IntEnum, shape=unsigned(4)):
FOO = 0
BAR = 1

.. doctest::

>>> a = Signal(TransparentEnum)
>>> type(a)
<class 'amaranth.hdl.ast.Signal'>

It is also possible to define a custom view class for a given enum:

.. testcode::

class InstrView(enum.EnumView):
def has_immediate(self):
return (self == Instr.ADDI) | (self == Instr.SUBI)

class Instr(enum.Enum, shape=5, view_class=InstrView):
ADD = Cat(Funct.ADD, Op.REG)
ADDI = Cat(Funct.ADD, Op.IMM)
SUB = Cat(Funct.SUB, Op.REG)
SUBI = Cat(Funct.SUB, Op.IMM)

.. doctest::

>>> a = Signal(Instr)
>>> type(a)
<class 'InstrView'>
>>> a.has_immediate()
(| (== (sig a) (const 5'd16)) (== (sig a) (const 5'd17)))

Metaclass
=========
Expand All @@ -71,3 +124,9 @@ Base classes
.. autoclass:: IntEnum()
.. autoclass:: Flag()
.. autoclass:: IntFlag()

View classes
============

.. autoclass:: EnumView()
.. autoclass:: FlagView()
Loading