diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index d99cd7243..3f87e9e44 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -1946,6 +1946,57 @@ def __call__(cls, shape=None, src_loc_at=0, **kwargs): return signal +# also used for MemoryData.Init +def _get_init_value(init, shape, what="signal"): + orig_init = init + orig_shape = shape + shape = Shape.cast(shape) + if isinstance(orig_shape, ShapeCastable): + try: + init = Const.cast(orig_shape.const(init)) + except Exception: + raise TypeError(f"Initial value must be a constant initializer of {orig_shape!r}") + if init.shape() != Shape.cast(shape): + raise ValueError(f"Constant returned by {orig_shape!r}.const() must have the shape " + f"that it casts to, {shape!r}, and not {init.shape()!r}") + return init.value + else: + if init is None: + init = 0 + try: + init = Const.cast(init) + except TypeError: + raise TypeError("Initial value must be a constant-castable expression, not {!r}" + .format(orig_init)) + # Avoid false positives for all-zeroes and all-ones + if orig_init is not None and not (isinstance(orig_init, int) and orig_init in (0, -1)): + if init.shape().signed and not shape.signed: + warnings.warn( + message=f"Initial value {orig_init!r} is signed, " + f"but the {what} shape is {shape!r}", + category=SyntaxWarning, + stacklevel=2) + elif (init.shape().width > shape.width or + init.shape().width == shape.width and + shape.signed and not init.shape().signed): + warnings.warn( + message=f"Initial value {orig_init!r} will be truncated to " + f"the {what} shape {shape!r}", + category=SyntaxWarning, + stacklevel=2) + + if isinstance(orig_shape, range) and orig_init is not None and orig_init not in orig_shape: + if orig_init == orig_shape.stop: + raise SyntaxError( + f"Initial value {orig_init!r} equals the non-inclusive end of the {what} " + f"shape {orig_shape!r}; this is likely an off-by-one error") + else: + raise SyntaxError( + f"Initial value {orig_init!r} is not within the {what} shape {orig_shape!r}") + + return Const(init.value, shape).value + + @final class Signal(Value, DUID, metaclass=_SignalMeta): """A varying integer value. @@ -2016,54 +2067,9 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F DeprecationWarning, stacklevel=2) init = reset - orig_init = init - if isinstance(orig_shape, ShapeCastable): - try: - init = Const.cast(orig_shape.const(init)) - except Exception: - raise TypeError("Initial value must be a constant initializer of {!r}" - .format(orig_shape)) - if init.shape() != Shape.cast(orig_shape): - raise ValueError("Constant returned by {!r}.const() must have the shape that " - "it casts to, {!r}, and not {!r}" - .format(orig_shape, Shape.cast(orig_shape), - init.shape())) - else: - if init is None: - init = 0 - try: - init = Const.cast(init) - except TypeError: - raise TypeError("Initial value must be a constant-castable expression, not {!r}" - .format(orig_init)) - # Avoid false positives for all-zeroes and all-ones - if orig_init is not None and not (isinstance(orig_init, int) and orig_init in (0, -1)): - if init.shape().signed and not self._signed: - warnings.warn( - message="Initial value {!r} is signed, but the signal shape is {!r}" - .format(orig_init, shape), - category=SyntaxWarning, - stacklevel=2) - elif (init.shape().width > self._width or - init.shape().width == self._width and - self._signed and not init.shape().signed): - warnings.warn( - message="Initial value {!r} will be truncated to the signal shape {!r}" - .format(orig_init, shape), - category=SyntaxWarning, - stacklevel=2) - self._init = Const(init.value, shape).value + self._init = _get_init_value(init, unsigned(1) if orig_shape is None else orig_shape) self._reset_less = bool(reset_less) - if isinstance(orig_shape, range) and orig_init is not None and orig_init not in orig_shape: - if orig_init == orig_shape.stop: - raise SyntaxError( - f"Initial value {orig_init!r} equals the non-inclusive end of the signal " - f"shape {orig_shape!r}; this is likely an off-by-one error") - else: - raise SyntaxError( - f"Initial value {orig_init!r} is not within the signal shape {orig_shape!r}") - self._attrs = OrderedDict(() if attrs is None else attrs) if isinstance(orig_shape, ShapeCastable): diff --git a/amaranth/hdl/_mem.py b/amaranth/hdl/_mem.py index 0554f738f..ab348551f 100644 --- a/amaranth/hdl/_mem.py +++ b/amaranth/hdl/_mem.py @@ -3,6 +3,7 @@ from .. import tracer from ._ast import * +from ._ast import _get_init_value from ._ir import Fragment, AlreadyElaborated from ..utils import ceil_log2 from .._utils import final @@ -105,10 +106,11 @@ def __setitem__(self, index, value): for actual_index, actual_value in zip(indices, value): self[actual_index] = actual_value else: + raw = _get_init_value(value, self._shape, "memory") if isinstance(self._shape, ShapeCastable): - self._raw[index] = Const.cast(Const(value, self._shape)).value + self._raw[index] = raw else: - value = operator.index(value) + value = raw # self._raw[index] assigned by the following line self._elems[index] = value diff --git a/tests/test_hdl_mem.py b/tests/test_hdl_mem.py index 504e5984d..fb4596566 100644 --- a/tests/test_hdl_mem.py +++ b/tests/test_hdl_mem.py @@ -31,3 +31,29 @@ def test_row_elab(self): with self.assertRaisesRegex(ValueError, r"^Value \(memory-row \(memory-data data\) 0\) can only be used in simulator processes$"): m.d.comb += data[0].eq(1) + + +class InitTestCase(FHDLTestCase): + def test_ones(self): + init = MemoryData.Init([-1, 12], shape=8, depth=2) + self.assertEqual(list(init), [0xff, 12]) + init = MemoryData.Init([-1, -12], shape=signed(8), depth=2) + self.assertEqual(list(init), [-1, -12]) + + def test_trunc(self): + with self.assertWarnsRegex(SyntaxWarning, + r"^Initial value -2 is signed, but the memory shape is unsigned\(8\)$"): + init = MemoryData.Init([-2, 12], shape=8, depth=2) + self.assertEqual(list(init), [0xfe, 12]) + with self.assertWarnsRegex(SyntaxWarning, + r"^Initial value 258 will be truncated to the memory shape unsigned\(8\)$"): + init = MemoryData.Init([258, 129], shape=8, depth=2) + self.assertEqual(list(init), [2, 129]) + with self.assertWarnsRegex(SyntaxWarning, + r"^Initial value 128 will be truncated to the memory shape signed\(8\)$"): + init = MemoryData.Init([128], shape=signed(8), depth=1) + self.assertEqual(list(init), [-128]) + with self.assertWarnsRegex(SyntaxWarning, + r"^Initial value -129 will be truncated to the memory shape signed\(8\)$"): + init = MemoryData.Init([-129], shape=signed(8), depth=1) + self.assertEqual(list(init), [127]) diff --git a/tests/test_lib_memory.py b/tests/test_lib_memory.py index 3700613f3..ba828a8b3 100644 --- a/tests/test_lib_memory.py +++ b/tests/test_lib_memory.py @@ -329,7 +329,7 @@ def test_constructor_wrong(self): memory.Memory(shape="a", depth=3, init=[]) with self.assertRaisesRegex(TypeError, (r"^Memory initialization value at address 1: " - r"'str' object cannot be interpreted as an integer$")): + r"Initial value must be a constant-castable expression, not '0'$")): memory.Memory(shape=8, depth=4, init=[1, "0"]) with self.assertRaisesRegex(ValueError, r"^Either 'data' or 'shape' needs to be given$"): @@ -373,7 +373,7 @@ def test_init_set_shapecastable(self): def test_init_set_wrong(self): m = memory.Memory(shape=8, depth=4, init=[]) with self.assertRaisesRegex(TypeError, - r"^'str' object cannot be interpreted as an integer$"): + r"^Initial value must be a constant-castable expression, not 'a'$"): m.init[0] = "a" m = memory.Memory(shape=MyStruct, depth=4, init=[]) # underlying TypeError message differs between PyPy and CPython