hdl.mem: mask initial value to shape.
Fixes #1492.
wanda-phi authored and whitequark committed Aug 27, 2024
1 parent fff8f0b commit c03e450
Showing 4 changed files with 84 additions and 50 deletions.
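
To make the behavioral change concrete before the diff, here is an illustrative sketch distilled from the new tests in tests/test_hdl_mem.py below (not part of the commit itself; it assumes the usual amaranth.hdl import path for MemoryData):

from amaranth.hdl import MemoryData

# Memory initial values are now normalized to the memory's shape, like Signal init values.
init = MemoryData.Init([-1, 12], shape=8, depth=2)
assert list(init) == [0xff, 12]        # -1 wraps to 0xff for unsigned(8); all-ones is exempt from warnings

# Values that do not fit the shape emit a SyntaxWarning and are truncated rather than stored raw.
init = MemoryData.Init([258, 129], shape=8, depth=2)   # warns: truncated to unsigned(8)
assert list(init) == [2, 129]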
98 changes: 52 additions & 46 deletions amaranth/hdl/_ast.py
@@ -1946,6 +1946,57 @@ def __call__(cls, shape=None, src_loc_at=0, **kwargs):
         return signal
 
 
+# also used for MemoryData.Init
+def _get_init_value(init, shape, what="signal"):
+    orig_init = init
+    orig_shape = shape
+    shape = Shape.cast(shape)
+    if isinstance(orig_shape, ShapeCastable):
+        try:
+            init = Const.cast(orig_shape.const(init))
+        except Exception:
+            raise TypeError(f"Initial value must be a constant initializer of {orig_shape!r}")
+        if init.shape() != Shape.cast(shape):
+            raise ValueError(f"Constant returned by {orig_shape!r}.const() must have the shape "
+                             f"that it casts to, {shape!r}, and not {init.shape()!r}")
+        return init.value
+    else:
+        if init is None:
+            init = 0
+        try:
+            init = Const.cast(init)
+        except TypeError:
+            raise TypeError("Initial value must be a constant-castable expression, not {!r}"
+                            .format(orig_init))
+        # Avoid false positives for all-zeroes and all-ones
+        if orig_init is not None and not (isinstance(orig_init, int) and orig_init in (0, -1)):
+            if init.shape().signed and not shape.signed:
+                warnings.warn(
+                    message=f"Initial value {orig_init!r} is signed, "
+                            f"but the {what} shape is {shape!r}",
+                    category=SyntaxWarning,
+                    stacklevel=2)
+            elif (init.shape().width > shape.width or
+                  init.shape().width == shape.width and
+                  shape.signed and not init.shape().signed):
+                warnings.warn(
+                    message=f"Initial value {orig_init!r} will be truncated to "
+                            f"the {what} shape {shape!r}",
+                    category=SyntaxWarning,
+                    stacklevel=2)
+
+        if isinstance(orig_shape, range) and orig_init is not None and orig_init not in orig_shape:
+            if orig_init == orig_shape.stop:
+                raise SyntaxError(
+                    f"Initial value {orig_init!r} equals the non-inclusive end of the {what} "
+                    f"shape {orig_shape!r}; this is likely an off-by-one error")
+            else:
+                raise SyntaxError(
+                    f"Initial value {orig_init!r} is not within the {what} shape {orig_shape!r}")
+
+        return Const(init.value, shape).value
+
+
 @final
 class Signal(Value, DUID, metaclass=_SignalMeta):
     """A varying integer value.
@@ -2016,54 +2067,9 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F
                           DeprecationWarning, stacklevel=2)
             init = reset
 
-        orig_init = init
-        if isinstance(orig_shape, ShapeCastable):
-            try:
-                init = Const.cast(orig_shape.const(init))
-            except Exception:
-                raise TypeError("Initial value must be a constant initializer of {!r}"
-                                .format(orig_shape))
-            if init.shape() != Shape.cast(orig_shape):
-                raise ValueError("Constant returned by {!r}.const() must have the shape that "
-                                 "it casts to, {!r}, and not {!r}"
-                                 .format(orig_shape, Shape.cast(orig_shape),
-                                         init.shape()))
-        else:
-            if init is None:
-                init = 0
-            try:
-                init = Const.cast(init)
-            except TypeError:
-                raise TypeError("Initial value must be a constant-castable expression, not {!r}"
-                                .format(orig_init))
-            # Avoid false positives for all-zeroes and all-ones
-            if orig_init is not None and not (isinstance(orig_init, int) and orig_init in (0, -1)):
-                if init.shape().signed and not self._signed:
-                    warnings.warn(
-                        message="Initial value {!r} is signed, but the signal shape is {!r}"
-                                .format(orig_init, shape),
-                        category=SyntaxWarning,
-                        stacklevel=2)
-                elif (init.shape().width > self._width or
-                      init.shape().width == self._width and
-                      self._signed and not init.shape().signed):
-                    warnings.warn(
-                        message="Initial value {!r} will be truncated to the signal shape {!r}"
-                                .format(orig_init, shape),
-                        category=SyntaxWarning,
-                        stacklevel=2)
-        self._init = Const(init.value, shape).value
+        self._init = _get_init_value(init, unsigned(1) if orig_shape is None else orig_shape)
         self._reset_less = bool(reset_less)
 
-        if isinstance(orig_shape, range) and orig_init is not None and orig_init not in orig_shape:
-            if orig_init == orig_shape.stop:
-                raise SyntaxError(
-                    f"Initial value {orig_init!r} equals the non-inclusive end of the signal "
-                    f"shape {orig_shape!r}; this is likely an off-by-one error")
-            else:
-                raise SyntaxError(
-                    f"Initial value {orig_init!r} is not within the signal shape {orig_shape!r}")
-
         self._attrs = OrderedDict(() if attrs is None else attrs)
 
         if isinstance(orig_shape, ShapeCastable):
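
For context, an illustrative sketch of Signal behavior after this refactor (its semantics are unchanged; it now simply delegates to the shared _get_init_value helper). The use of the public Signal.init attribute here is an assumption based on current Amaranth:

from amaranth.hdl import Signal

sig = Signal(8, init=-1)       # all-ones is exempt from the signedness warning; masked to the shape
assert sig.init == 0xff

# The range off-by-one check is preserved as well:
# Signal(range(10), init=10)   # raises SyntaxError: equals the non-inclusive end of range(0, 10)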
6 changes: 4 additions & 2 deletions amaranth/hdl/_mem.py
@@ -3,6 +3,7 @@
 
 from .. import tracer
 from ._ast import *
+from ._ast import _get_init_value
 from ._ir import Fragment, AlreadyElaborated
 from ..utils import ceil_log2
 from .._utils import final
@@ -105,10 +106,11 @@ def __setitem__(self, index, value):
                 for actual_index, actual_value in zip(indices, value):
                     self[actual_index] = actual_value
             else:
+                raw = _get_init_value(value, self._shape, "memory")
                 if isinstance(self._shape, ShapeCastable):
-                    self._raw[index] = Const.cast(Const(value, self._shape)).value
+                    self._raw[index] = raw
                 else:
-                    value = operator.index(value)
+                    value = raw
                 # self._raw[index] assigned by the following line
                 self._elems[index] = value
 
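
A sketch of what the __setitem__ change above means for assigning into a memory's init (illustrative only; it reuses the amaranth.lib.memory API exercised by the tests below, and the read-back of the masked value via indexing is an assumption):

from amaranth.lib import memory

m = memory.Memory(shape=8, depth=4, init=[])
m.init[0] = -1                # now normalized through _get_init_value and masked to unsigned(8)
assert m.init[0] == 0xff      # assumed read-back of the masked value

# m.init[1] = "a"             # raises TypeError: Initial value must be a constant-castable expression, not 'a'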
26 changes: 26 additions & 0 deletions tests/test_hdl_mem.py
@@ -31,3 +31,29 @@ def test_row_elab(self):
         with self.assertRaisesRegex(ValueError,
                 r"^Value \(memory-row \(memory-data data\) 0\) can only be used in simulator processes$"):
             m.d.comb += data[0].eq(1)
+
+
+class InitTestCase(FHDLTestCase):
+    def test_ones(self):
+        init = MemoryData.Init([-1, 12], shape=8, depth=2)
+        self.assertEqual(list(init), [0xff, 12])
+        init = MemoryData.Init([-1, -12], shape=signed(8), depth=2)
+        self.assertEqual(list(init), [-1, -12])
+
+    def test_trunc(self):
+        with self.assertWarnsRegex(SyntaxWarning,
+                r"^Initial value -2 is signed, but the memory shape is unsigned\(8\)$"):
+            init = MemoryData.Init([-2, 12], shape=8, depth=2)
+        self.assertEqual(list(init), [0xfe, 12])
+        with self.assertWarnsRegex(SyntaxWarning,
+                r"^Initial value 258 will be truncated to the memory shape unsigned\(8\)$"):
+            init = MemoryData.Init([258, 129], shape=8, depth=2)
+        self.assertEqual(list(init), [2, 129])
+        with self.assertWarnsRegex(SyntaxWarning,
+                r"^Initial value 128 will be truncated to the memory shape signed\(8\)$"):
+            init = MemoryData.Init([128], shape=signed(8), depth=1)
+        self.assertEqual(list(init), [-128])
+        with self.assertWarnsRegex(SyntaxWarning,
+                r"^Initial value -129 will be truncated to the memory shape signed\(8\)$"):
+            init = MemoryData.Init([-129], shape=signed(8), depth=1)
+        self.assertEqual(list(init), [127])
4 changes: 2 additions & 2 deletions tests/test_lib_memory.py
@@ -329,7 +329,7 @@ def test_constructor_wrong(self):
             memory.Memory(shape="a", depth=3, init=[])
         with self.assertRaisesRegex(TypeError,
                 (r"^Memory initialization value at address 1: "
-                 r"'str' object cannot be interpreted as an integer$")):
+                 r"Initial value must be a constant-castable expression, not '0'$")):
             memory.Memory(shape=8, depth=4, init=[1, "0"])
         with self.assertRaisesRegex(ValueError,
                 r"^Either 'data' or 'shape' needs to be given$"):
@@ -373,7 +373,7 @@ def test_init_set_shapecastable(self):
     def test_init_set_wrong(self):
         m = memory.Memory(shape=8, depth=4, init=[])
         with self.assertRaisesRegex(TypeError,
-                r"^'str' object cannot be interpreted as an integer$"):
+                r"^Initial value must be a constant-castable expression, not 'a'$"):
            m.init[0] = "a"
        m = memory.Memory(shape=MyStruct, depth=4, init=[])
        # underlying TypeError message differs between PyPy and CPython