diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index 02e9dfe603..f59d2f4120 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -1,11 +1,12 @@ import enum as py_enum import warnings +import operator -from ..hdl.ast import Value, Shape, ShapeCastable, Const +from ..hdl.ast import Value, ValueCastable, Shape, ShapeCastable, Const from ..hdl._repr import * -__all__ = py_enum.__all__ +__all__ = py_enum.__all__ + ["EnumView", "FlagView"] for _member in py_enum.__all__: @@ -27,10 +28,10 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta): # TODO: remove this shim once py3.8 support is dropped @classmethod - def __prepare__(metacls, name, bases, shape=None, **kwargs): + def __prepare__(metacls, name, bases, shape=None, view_class=None, **kwargs): return super().__prepare__(name, bases, **kwargs) - def __new__(metacls, name, bases, namespace, shape=None, **kwargs): + def __new__(metacls, name, bases, namespace, shape=None, view_class=None, **kwargs): if shape is not None: shape = Shape.cast(shape) # Prepare enumeration members for instantiation. This logic is unfortunately very @@ -89,6 +90,8 @@ def __new__(metacls, name, bases, namespace, shape=None, **kwargs): # Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that # the values of every member can be cast to the provided shape without truncation. cls._amaranth_shape_ = shape + if view_class is not None: + cls._amaranth_view_class_ = view_class else: # Shape is not provided explicitly. Behave the same as a standard enumeration; # the lack of `_amaranth_shape_` attribute is used to emit a warning when such @@ -136,8 +139,12 @@ def __call__(cls, value, *args, **kwargs): # At the moment however, for historical reasons, this is just the value itself. This works # and is backwards-compatible but is limiting in that it does not allow us to e.g. catch # comparisons with enum members of the wrong type. - if isinstance(value, Value): - return value + if isinstance(value, (Value, ValueCastable)): + value = Value.cast(value) + if cls._amaranth_view_class_ is None: + return value + else: + return cls._amaranth_view_class_(cls, value) return super().__call__(value, *args, **kwargs) def const(cls, init): @@ -149,7 +156,7 @@ def const(cls, init): member = cls(0) else: member = cls(init) - return Const(member.value, cls.as_shape()) + return cls(Const(member.value, cls.as_shape())) def _value_repr(cls, value): yield Repr(FormatEnum(cls), value) @@ -174,9 +181,127 @@ class IntFlag(py_enum.IntFlag): """Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as its metaclass.""" + # Fix up the metaclass after the fact: the metaclass __new__ requires these classes # to already be present, and also would not install itself on them due to lack of shape. Enum.__class__ = EnumMeta IntEnum.__class__ = EnumMeta Flag.__class__ = EnumMeta IntFlag.__class__ = EnumMeta + + +class EnumView(ValueCastable): + def __init__(self, enum, target): + if not isinstance(enum, EnumMeta) or not hasattr(enum, "_amaranth_shape_"): + raise TypeError(f"EnumView type must be an enum with shape, not {enum!r}") + try: + cast_target = Value.cast(target) + except TypeError as e: + raise TypeError("EnumView target must be a value-castable object, not {!r}" + .format(target)) from e + if cast_target.shape() != enum.as_shape(): + raise TypeError("EnumView target must have the same shape as the enum") + self.enum = enum + self.target = cast_target + + def shape(self): + return self.enum + + @ValueCastable.lowermethod + def as_value(self): + return self.target + + def eq(self, other): + """Assign to the underlying value. + + Returns + ------- + :class:`Assign` + ``self.as_value().eq(other)`` + """ + return self.as_value().eq(other) + + def __add__(self, other): + raise TypeError("cannot perform arithmetic operations on non-IntEnum enum") + + __radd__ = __add__ + __sub__ = __add__ + __rsub__ = __add__ + __mul__ = __add__ + __rmul__ = __add__ + __floordiv__ = __add__ + __rfloordiv__ = __add__ + __mod__ = __add__ + __rmod__ = __add__ + __lshift__ = __add__ + __rlshift__ = __add__ + __rshift__ = __add__ + __rrshift__ = __add__ + __lt__ = __add__ + __le__ = __add__ + __gt__ = __add__ + __ge__ = __add__ + + def __and__(self, other): + raise TypeError("cannot perform bitwise operations on non-IntEnum non-Flag enum") + + __rand__ = __and__ + __or__ = __and__ + __ror__ = __and__ + __xor__ = __and__ + __rxor__ = __and__ + + def __eq__(self, other): + if isinstance(other, self.enum): + other = self.enum(Value.cast(other)) + if not isinstance(other, EnumView) or other.enum is not self.enum: + raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type") + return self.target == other.target + + def __ne__(self, other): + if isinstance(other, self.enum): + other = self.enum(Value.cast(other)) + if not isinstance(other, EnumView) or other.enum is not self.enum: + raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type") + return self.target != other.target + + def __repr__(self): + return f"{type(self).__name__}({self.enum.__name__}, {self.target!r})" + + +class FlagView(EnumView): + def __invert__(self): + if hasattr(self.enum, "_boundary_") and self.enum._boundary_ in (EJECT, KEEP): + return self.enum._amaranth_view_class_(self.enum, ~self.target) + else: + singles_mask = 0 + for flag in self.enum: + if (flag.value & (flag.value - 1)) == 0: + singles_mask |= flag.value + return self.enum._amaranth_view_class_(self.enum, ~self.target & singles_mask) + + def __bitop(self, other, op): + if isinstance(other, self.enum): + other = self.enum(Value.cast(other)) + if not isinstance(other, FlagView) or other.enum is not self.enum: + raise TypeError("a FlagView can only perform bitwise operation with a value or other FlagView of the same enum type") + return self.enum._amaranth_view_class_(self.enum, op(self.target, other.target)) + + def __and__(self, other): + return self.__bitop(other, operator.__and__) + + def __or__(self, other): + return self.__bitop(other, operator.__or__) + + def __xor__(self, other): + return self.__bitop(other, operator.__xor__) + + __rand__ = __and__ + __ror__ = __or__ + __rxor__ = __xor__ + + +Enum._amaranth_view_class_ = EnumView +IntEnum._amaranth_view_class_ = None +Flag._amaranth_view_class_ = FlagView +IntFlag._amaranth_view_class_ = None diff --git a/tests/test_lib_enum.py b/tests/test_lib_enum.py index 99425ffcc8..83dd038305 100644 --- a/tests/test_lib_enum.py +++ b/tests/test_lib_enum.py @@ -1,7 +1,9 @@ import enum as py_enum +import operator +import sys from amaranth import * -from amaranth.lib.enum import Enum, EnumMeta +from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView from .utils import * @@ -103,9 +105,9 @@ def test_const_shape(self): class EnumA(Enum, shape=8): Z = 0 A = 10 - self.assertRepr(EnumA.const(None), "(const 8'd0)") - self.assertRepr(EnumA.const(10), "(const 8'd10)") - self.assertRepr(EnumA.const(EnumA.A), "(const 8'd10)") + self.assertRepr(EnumA.const(None), "EnumView(EnumA, (const 8'd0))") + self.assertRepr(EnumA.const(10), "EnumView(EnumA, (const 8'd10))") + self.assertRepr(EnumA.const(EnumA.A), "EnumView(EnumA, (const 8'd10))") def test_shape_implicit_wrong_in_concat(self): class EnumA(Enum): @@ -118,3 +120,171 @@ class EnumA(Enum): def test_functional(self): Enum("FOO", ["BAR", "BAZ"]) + + def test_int_enum(self): + class EnumA(IntEnum, shape=signed(4)): + A = 0 + B = -3 + a = Signal(EnumA) + self.assertRepr(a, "(sig a)") + + def test_enum_view(self): + class EnumA(Enum, shape=signed(4)): + A = 0 + B = -3 + class EnumB(Enum, shape=signed(4)): + C = 0 + D = 5 + a = Signal(EnumA) + b = Signal(EnumB) + c = Signal(EnumA) + d = Signal(4) + self.assertIsInstance(a, EnumView) + self.assertIs(a.shape(), EnumA) + self.assertRepr(a, "EnumView(EnumA, (sig a))") + self.assertRepr(a.as_value(), "(sig a)") + self.assertRepr(a.eq(c), "(eq (sig a) (sig c))") + for op in [ + operator.__add__, + operator.__sub__, + operator.__mul__, + operator.__floordiv__, + operator.__mod__, + operator.__lshift__, + operator.__rshift__, + operator.__and__, + operator.__or__, + operator.__xor__, + operator.__lt__, + operator.__le__, + operator.__gt__, + operator.__ge__, + ]: + with self.assertRaises(TypeError): + op(a, a) + with self.assertRaises(TypeError): + op(a, d) + with self.assertRaises(TypeError): + op(d, a) + with self.assertRaises(TypeError): + op(a, 3) + with self.assertRaises(TypeError): + op(a, EnumA.A) + for op in [ + operator.__eq__, + operator.__ne__, + ]: + with self.assertRaises(TypeError): + op(a, b) + with self.assertRaises(TypeError): + op(a, d) + with self.assertRaises(TypeError): + op(d, a) + with self.assertRaises(TypeError): + op(a, 3) + with self.assertRaises(TypeError): + op(a, EnumB.C) + self.assertRepr(a == c, "(== (sig a) (sig c))") + self.assertRepr(a != c, "(!= (sig a) (sig c))") + self.assertRepr(a == EnumA.B, "(== (sig a) (const 4'sd-3))") + self.assertRepr(EnumA.B == a, "(== (sig a) (const 4'sd-3))") + self.assertRepr(a != EnumA.B, "(!= (sig a) (const 4'sd-3))") + + def test_flag_view(self): + class FlagA(Flag, shape=unsigned(4)): + A = 1 + B = 4 + class FlagB(Flag, shape=unsigned(4)): + C = 1 + D = 2 + a = Signal(FlagA) + b = Signal(FlagB) + c = Signal(FlagA) + d = Signal(4) + self.assertIsInstance(a, FlagView) + self.assertRepr(a, "FlagView(FlagA, (sig a))") + for op in [ + operator.__add__, + operator.__sub__, + operator.__mul__, + operator.__floordiv__, + operator.__mod__, + operator.__lshift__, + operator.__rshift__, + operator.__lt__, + operator.__le__, + operator.__gt__, + operator.__ge__, + ]: + with self.assertRaises(TypeError): + op(a, a) + with self.assertRaises(TypeError): + op(a, d) + with self.assertRaises(TypeError): + op(d, a) + with self.assertRaises(TypeError): + op(a, 3) + with self.assertRaises(TypeError): + op(a, FlagA.A) + for op in [ + operator.__eq__, + operator.__ne__, + operator.__and__, + operator.__or__, + operator.__xor__, + ]: + with self.assertRaises(TypeError): + op(a, b) + with self.assertRaises(TypeError): + op(a, d) + with self.assertRaises(TypeError): + op(d, a) + with self.assertRaises(TypeError): + op(a, 3) + with self.assertRaises(TypeError): + op(a, FlagB.C) + self.assertRepr(a == c, "(== (sig a) (sig c))") + self.assertRepr(a != c, "(!= (sig a) (sig c))") + self.assertRepr(a == FlagA.B, "(== (sig a) (const 4'd4))") + self.assertRepr(FlagA.B == a, "(== (sig a) (const 4'd4))") + self.assertRepr(a != FlagA.B, "(!= (sig a) (const 4'd4))") + self.assertRepr(a | c, "FlagView(FlagA, (| (sig a) (sig c)))") + self.assertRepr(a & c, "FlagView(FlagA, (& (sig a) (sig c)))") + self.assertRepr(a ^ c, "FlagView(FlagA, (^ (sig a) (sig c)))") + self.assertRepr(~a, "FlagView(FlagA, (& (~ (sig a)) (const 3'd5)))") + self.assertRepr(a | FlagA.B, "FlagView(FlagA, (| (sig a) (const 4'd4)))") + if sys.version_info >= (3, 11): + class FlagC(Flag, shape=unsigned(4), boundary=py_enum.KEEP): + A = 1 + B = 4 + e = Signal(FlagC) + self.assertRepr(~e, "FlagView(FlagC, (~ (sig e)))") + + def test_enum_view_wrong(self): + class EnumA(Enum, shape=signed(4)): + A = 0 + B = -3 + + a = Signal(2) + with self.assertRaisesRegex(TypeError, + r'^EnumView target must have the same shape as the enum$'): + EnumA(a) + with self.assertRaisesRegex(TypeError, + r'^EnumView target must be a value-castable object, not .*$'): + EnumView(EnumA, "a") + + class EnumB(Enum): + C = 0 + D = 1 + with self.assertRaisesRegex(TypeError, + r'^EnumView type must be an enum with shape, not .*$'): + EnumView(EnumB, 3) + + def test_enum_view_custom(self): + class CustomView(EnumView): + pass + class EnumA(Enum, view_class=CustomView, shape=unsigned(2)): + A = 0 + B = 1 + a = Signal(EnumA) + assert isinstance(a, CustomView)