Skip to content

Commit

Permalink
Implement RFC 31: Enumeration type safety.
Browse files Browse the repository at this point in the history
  • Loading branch information
wanda-phi committed Nov 28, 2023
1 parent b0b193f commit 229a4a8
Show file tree
Hide file tree
Showing 2 changed files with 306 additions and 11 deletions.
139 changes: 132 additions & 7 deletions amaranth/lib/enum.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import enum as py_enum
import warnings
import operator

from ..hdl.ast import Value, Shape, ShapeCastable, Const
from ..hdl.ast import Value, ValueCastable, Shape, ShapeCastable, Const
from ..hdl._repr import *


__all__ = py_enum.__all__
__all__ = py_enum.__all__ + ["EnumView", "FlagView"]


for _member in py_enum.__all__:
Expand All @@ -27,10 +28,10 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):

# TODO: remove this shim once py3.8 support is dropped
@classmethod
def __prepare__(metacls, name, bases, shape=None, **kwargs):
def __prepare__(metacls, name, bases, shape=None, view_class=None, **kwargs):
return super().__prepare__(name, bases, **kwargs)

def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
def __new__(metacls, name, bases, namespace, shape=None, view_class=None, **kwargs):
if shape is not None:
shape = Shape.cast(shape)
# Prepare enumeration members for instantiation. This logic is unfortunately very
Expand Down Expand Up @@ -89,6 +90,8 @@ def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
# Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that
# the values of every member can be cast to the provided shape without truncation.
cls._amaranth_shape_ = shape
if view_class is not None:
cls._amaranth_view_class_ = view_class
else:
# Shape is not provided explicitly. Behave the same as a standard enumeration;
# the lack of `_amaranth_shape_` attribute is used to emit a warning when such
Expand Down Expand Up @@ -136,8 +139,12 @@ def __call__(cls, value, *args, **kwargs):
# At the moment however, for historical reasons, this is just the value itself. This works
# and is backwards-compatible but is limiting in that it does not allow us to e.g. catch
# comparisons with enum members of the wrong type.
if isinstance(value, Value):
return value
if isinstance(value, (Value, ValueCastable)):
value = Value.cast(value)
if cls._amaranth_view_class_ is None:
return value
else:
return cls._amaranth_view_class_(cls, value)
return super().__call__(value, *args, **kwargs)

def const(cls, init):
Expand All @@ -149,7 +156,7 @@ def const(cls, init):
member = cls(0)
else:
member = cls(init)
return Const(member.value, cls.as_shape())
return cls(Const(member.value, cls.as_shape()))

def _value_repr(cls, value):
yield Repr(FormatEnum(cls), value)
Expand All @@ -174,9 +181,127 @@ class IntFlag(py_enum.IntFlag):
"""Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as
its metaclass."""


# Fix up the metaclass after the fact: the metaclass __new__ requires these classes
# to already be present, and also would not install itself on them due to lack of shape.
Enum.__class__ = EnumMeta
IntEnum.__class__ = EnumMeta
Flag.__class__ = EnumMeta
IntFlag.__class__ = EnumMeta


class EnumView(ValueCastable):
def __init__(self, enum, target):
if not isinstance(enum, EnumMeta) or not hasattr(enum, "_amaranth_shape_"):
raise TypeError(f"EnumView type must be an enum with shape, not {enum!r}")
try:
cast_target = Value.cast(target)
except TypeError as e:
raise TypeError("EnumView target must be a value-castable object, not {!r}"
.format(target)) from e
if cast_target.shape() != enum.as_shape():
raise TypeError("EnumView target must have the same shape as the enum")
self.enum = enum
self.target = cast_target

def shape(self):
return self.enum

@ValueCastable.lowermethod
def as_value(self):
return self.target

def eq(self, other):
"""Assign to the underlying value.
Returns
-------
:class:`Assign`
``self.as_value().eq(other)``
"""
return self.as_value().eq(other)

def __add__(self, other):
raise TypeError("cannot perform arithmetic operations on non-IntEnum enum")

__radd__ = __add__
__sub__ = __add__
__rsub__ = __add__
__mul__ = __add__
__rmul__ = __add__
__floordiv__ = __add__
__rfloordiv__ = __add__
__mod__ = __add__
__rmod__ = __add__
__lshift__ = __add__
__rlshift__ = __add__
__rshift__ = __add__
__rrshift__ = __add__
__lt__ = __add__
__le__ = __add__
__gt__ = __add__
__ge__ = __add__

def __and__(self, other):
raise TypeError("cannot perform bitwise operations on non-IntEnum non-Flag enum")

__rand__ = __and__
__or__ = __and__
__ror__ = __and__
__xor__ = __and__
__rxor__ = __and__

def __eq__(self, other):
if isinstance(other, self.enum):
other = self.enum(Value.cast(other))
if not isinstance(other, EnumView) or other.enum is not self.enum:
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
return self.target == other.target

def __ne__(self, other):
if isinstance(other, self.enum):
other = self.enum(Value.cast(other))
if not isinstance(other, EnumView) or other.enum is not self.enum:
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
return self.target != other.target

def __repr__(self):
return f"{type(self).__name__}({self.enum.__name__}, {self.target!r})"


class FlagView(EnumView):
def __invert__(self):
if hasattr(self.enum, "_boundary_") and self.enum._boundary_ in (EJECT, KEEP):
return self.enum._amaranth_view_class_(self.enum, ~self.target)
else:
singles_mask = 0
for flag in self.enum:
if (flag.value & (flag.value - 1)) == 0:
singles_mask |= flag.value
return self.enum._amaranth_view_class_(self.enum, ~self.target & singles_mask)

def __bitop(self, other, op):
if isinstance(other, self.enum):
other = self.enum(Value.cast(other))
if not isinstance(other, FlagView) or other.enum is not self.enum:
raise TypeError("a FlagView can only perform bitwise operation with a value or other FlagView of the same enum type")
return self.enum._amaranth_view_class_(self.enum, op(self.target, other.target))

def __and__(self, other):
return self.__bitop(other, operator.__and__)

def __or__(self, other):
return self.__bitop(other, operator.__or__)

def __xor__(self, other):
return self.__bitop(other, operator.__xor__)

__rand__ = __and__
__ror__ = __or__
__rxor__ = __xor__


Enum._amaranth_view_class_ = EnumView
IntEnum._amaranth_view_class_ = None
Flag._amaranth_view_class_ = FlagView
IntFlag._amaranth_view_class_ = None
178 changes: 174 additions & 4 deletions tests/test_lib_enum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import enum as py_enum
import operator
import sys

from amaranth import *
from amaranth.lib.enum import Enum, EnumMeta
from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView

from .utils import *

Expand Down Expand Up @@ -103,9 +105,9 @@ def test_const_shape(self):
class EnumA(Enum, shape=8):
Z = 0
A = 10
self.assertRepr(EnumA.const(None), "(const 8'd0)")
self.assertRepr(EnumA.const(10), "(const 8'd10)")
self.assertRepr(EnumA.const(EnumA.A), "(const 8'd10)")
self.assertRepr(EnumA.const(None), "EnumView(EnumA, (const 8'd0))")
self.assertRepr(EnumA.const(10), "EnumView(EnumA, (const 8'd10))")
self.assertRepr(EnumA.const(EnumA.A), "EnumView(EnumA, (const 8'd10))")

def test_shape_implicit_wrong_in_concat(self):
class EnumA(Enum):
Expand All @@ -118,3 +120,171 @@ class EnumA(Enum):

def test_functional(self):
Enum("FOO", ["BAR", "BAZ"])

def test_int_enum(self):
class EnumA(IntEnum, shape=signed(4)):
A = 0
B = -3
a = Signal(EnumA)
self.assertRepr(a, "(sig a)")

def test_enum_view(self):
class EnumA(Enum, shape=signed(4)):
A = 0
B = -3
class EnumB(Enum, shape=signed(4)):
C = 0
D = 5
a = Signal(EnumA)
b = Signal(EnumB)
c = Signal(EnumA)
d = Signal(4)
self.assertIsInstance(a, EnumView)
self.assertIs(a.shape(), EnumA)
self.assertRepr(a, "EnumView(EnumA, (sig a))")
self.assertRepr(a.as_value(), "(sig a)")
self.assertRepr(a.eq(c), "(eq (sig a) (sig c))")
for op in [
operator.__add__,
operator.__sub__,
operator.__mul__,
operator.__floordiv__,
operator.__mod__,
operator.__lshift__,
operator.__rshift__,
operator.__and__,
operator.__or__,
operator.__xor__,
operator.__lt__,
operator.__le__,
operator.__gt__,
operator.__ge__,
]:
with self.assertRaises(TypeError):
op(a, a)
with self.assertRaises(TypeError):
op(a, d)
with self.assertRaises(TypeError):
op(d, a)
with self.assertRaises(TypeError):
op(a, 3)
with self.assertRaises(TypeError):
op(a, EnumA.A)
for op in [
operator.__eq__,
operator.__ne__,
]:
with self.assertRaises(TypeError):
op(a, b)
with self.assertRaises(TypeError):
op(a, d)
with self.assertRaises(TypeError):
op(d, a)
with self.assertRaises(TypeError):
op(a, 3)
with self.assertRaises(TypeError):
op(a, EnumB.C)
self.assertRepr(a == c, "(== (sig a) (sig c))")
self.assertRepr(a != c, "(!= (sig a) (sig c))")
self.assertRepr(a == EnumA.B, "(== (sig a) (const 4'sd-3))")
self.assertRepr(EnumA.B == a, "(== (sig a) (const 4'sd-3))")
self.assertRepr(a != EnumA.B, "(!= (sig a) (const 4'sd-3))")

def test_flag_view(self):
class FlagA(Flag, shape=unsigned(4)):
A = 1
B = 4
class FlagB(Flag, shape=unsigned(4)):
C = 1
D = 2
a = Signal(FlagA)
b = Signal(FlagB)
c = Signal(FlagA)
d = Signal(4)
self.assertIsInstance(a, FlagView)
self.assertRepr(a, "FlagView(FlagA, (sig a))")
for op in [
operator.__add__,
operator.__sub__,
operator.__mul__,
operator.__floordiv__,
operator.__mod__,
operator.__lshift__,
operator.__rshift__,
operator.__lt__,
operator.__le__,
operator.__gt__,
operator.__ge__,
]:
with self.assertRaises(TypeError):
op(a, a)
with self.assertRaises(TypeError):
op(a, d)
with self.assertRaises(TypeError):
op(d, a)
with self.assertRaises(TypeError):
op(a, 3)
with self.assertRaises(TypeError):
op(a, FlagA.A)
for op in [
operator.__eq__,
operator.__ne__,
operator.__and__,
operator.__or__,
operator.__xor__,
]:
with self.assertRaises(TypeError):
op(a, b)
with self.assertRaises(TypeError):
op(a, d)
with self.assertRaises(TypeError):
op(d, a)
with self.assertRaises(TypeError):
op(a, 3)
with self.assertRaises(TypeError):
op(a, FlagB.C)
self.assertRepr(a == c, "(== (sig a) (sig c))")
self.assertRepr(a != c, "(!= (sig a) (sig c))")
self.assertRepr(a == FlagA.B, "(== (sig a) (const 4'd4))")
self.assertRepr(FlagA.B == a, "(== (sig a) (const 4'd4))")
self.assertRepr(a != FlagA.B, "(!= (sig a) (const 4'd4))")
self.assertRepr(a | c, "FlagView(FlagA, (| (sig a) (sig c)))")
self.assertRepr(a & c, "FlagView(FlagA, (& (sig a) (sig c)))")
self.assertRepr(a ^ c, "FlagView(FlagA, (^ (sig a) (sig c)))")
self.assertRepr(~a, "FlagView(FlagA, (& (~ (sig a)) (const 3'd5)))")
self.assertRepr(a | FlagA.B, "FlagView(FlagA, (| (sig a) (const 4'd4)))")
if sys.version_info >= (3, 11):
class FlagC(Flag, shape=unsigned(4), boundary=py_enum.KEEP):
A = 1
B = 4
e = Signal(FlagC)
self.assertRepr(~e, "FlagView(FlagC, (~ (sig e)))")

def test_enum_view_wrong(self):
class EnumA(Enum, shape=signed(4)):
A = 0
B = -3

a = Signal(2)
with self.assertRaisesRegex(TypeError,
r'^EnumView target must have the same shape as the enum$'):
EnumA(a)
with self.assertRaisesRegex(TypeError,
r'^EnumView target must be a value-castable object, not .*$'):
EnumView(EnumA, "a")

class EnumB(Enum):
C = 0
D = 1
with self.assertRaisesRegex(TypeError,
r'^EnumView type must be an enum with shape, not .*$'):
EnumView(EnumB, 3)

def test_enum_view_custom(self):
class CustomView(EnumView):
pass
class EnumA(Enum, view_class=CustomView, shape=unsigned(2)):
A = 0
B = 1
a = Signal(EnumA)
assert isinstance(a, CustomView)

0 comments on commit 229a4a8

Please sign in to comment.