Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Types] Add recursive wiring key ordering check #1317

Merged
merged 7 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions magma/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def is_wireable(cls, rhs):
if not isinstance(rhs, TupleKind) or len(cls.fields) != len(rhs.fields):
return False
for idx, T in enumerate(cls.fields):
if not T.is_wireable(rhs[idx]):
if not magma_type(T).is_wireable(rhs[idx]):
return False
return True

Expand All @@ -172,7 +172,7 @@ def is_bindable(cls, rhs):
if not isinstance(rhs, TupleKind) or len(cls.fields) != len(rhs.fields):
return False
for idx, T in enumerate(cls.fields):
if not T.is_bindable(rhs[idx]):
if not magma_type(T).is_bindable(rhs[idx]):
return False
return True

Expand Down Expand Up @@ -330,6 +330,21 @@ def wire(self, o, debug_info):
debug_info=debug_info
)
return

for i, k in enumerate(self.keys()):
if not type(self).fields[i].is_wireable(type(o).fields[i]):
_logger.error(
WiringLog(
f"Cannot wire {{}} (type={type(o)}, to "
f" {{}} (type={type(self)})"
f"because the key {k} is not wireable",
o,
self
),
debug_info=debug_info
)
return

if self._should_wire_children(o):
for self_elem, o_elem in zip(self, o):
self_elem = magma_value(self_elem)
Expand Down Expand Up @@ -585,23 +600,23 @@ def is_wireable(cls, rhs):
rhs = magma_type(rhs)
if (
not isinstance(rhs, AnonProductKind) or
len(cls.fields) != len(rhs.fields)
list(cls.field_dict.keys()) != list(rhs.field_dict.keys())
):
return False
for k, v in cls.field_dict.items():
if k not in rhs.field_dict or not v.is_wireable(rhs.field_dict[k]):
if not magma_type(v).is_wireable(rhs.field_dict[k]):
return False
return True

def is_bindable(cls, rhs):
rhs = magma_type(rhs)
if (
not isinstance(rhs, AnonProductKind) or
len(cls.fields) != len(rhs.fields)
list(cls.field_dict.keys()) != list(rhs.field_dict.keys())
):
return False
for k, v in cls.field_dict.items():
if k not in rhs.field_dict or not v.is_bindable(rhs.field_dict[k]):
if not magma_type(v).is_bindable(rhs.field_dict[k]):
return False
return True

Expand Down
5 changes: 4 additions & 1 deletion magma/wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@ def wire(o, i, debug_info=None):

i_T, o_T = type(i), type(o)
if not i_T.is_wireable(o_T):
# Escape curly braces in type string for format call.
o_T_str = str(o_T).replace("{", "{{").replace("}", "}}")
i_T_str = str(i_T).replace("{", "{{").replace("}", "}}")
_logger.error(
WiringLog(f"Cannot wire {{}} ({o_T}) to {{}} ({i_T})",
WiringLog(f"Cannot wire {{}} ({o_T_str}) to {{}} ({i_T_str})",
o, i),
debug_info=debug_info
)
Expand Down
38 changes: 36 additions & 2 deletions tests/test_type/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,45 @@ def test_tuple_key_ordering(caplog):

class tuple_key_ordering(m.Circuit):
io = m.IO(I=m.In(T0), O=m.Out(T1))
io.O.wire(io.I)

msg = """\
\033[1mtests/test_type/test_tuple.py:369\033[0m: Cannot wire tuple_key_ordering.I (type=Tuple(x=Out(Bit),y=Out(Bits[8])), keys=['y', 'x']) to tuple_key_ordering.O (type=Tuple(y=In(Bits[8]),x=In(Bit)), keys=['x', 'y']) because the tuples do not have the same keys
>> io.O.wire(io.I)\
"""

assert has_error(caplog, msg)


def test_tuple_key_ordering_recursive(caplog):
T0 = m.AnonProduct[{"x": m.Bit, "y": m.Bits[8]}]
T1 = m.AnonProduct[{"y": m.Bits[8], "x": m.Bit}]

class tuple_key_ordering(m.Circuit):
io = m.IO(I=m.In(m.Array[2, T0]), O=m.Out(m.Array[2, T1]))
io.O @= io.I

msg = """\
\033[1mtests/test_type/test_tuple.py:368\033[0m: Cannot wire tuple_key_ordering.I (type=Tuple(x=Out(Bit),y=Out(Bits[8])), keys=['y', 'x']) to tuple_key_ordering.O (type=Tuple(y=In(Bits[8]),x=In(Bit)), keys=['x', 'y']) because the tuples do not have the same keys
\033[1mtests/test_type/test_tuple.py:385\033[0m: Cannot wire tuple_key_ordering.I (Array[(2, AnonProduct[{'x': Bit[Out], 'y': Bits[(8, Out(Bit))]}])]) to tuple_key_ordering.O (Array[(2, AnonProduct[{'y': Bits[(8, In(Bit))], 'x': Bit[In]}])])
>> io.O @= io.I\
"""

assert has_error(caplog, "")
assert has_error(caplog, msg)


def test_tuple_key_ordering_recursive_2(caplog):
T0 = m.AnonProduct[{"x": m.Bit, "y": m.Bits[8]}]
T1 = m.AnonProduct[{"y": m.Bits[8], "x": m.Bit}]
T2 = m.AnonProduct[{"y": T0, "x": m.Bit}]
T3 = m.AnonProduct[{"y": T1, "x": m.Bit}]

class tuple_key_ordering(m.Circuit):
io = m.IO(I=m.In(T2), O=m.Out(T3))
io.O.wire(io.I)

msg = """\
\033[1mtests/test_type/test_tuple.py:403\033[0m: Cannot wire tuple_key_ordering.I (type=Tuple(y=Tuple(x=Out(Bit),y=Out(Bits[8])),x=Out(Bit)), to tuple_key_ordering.O (type=Tuple(y=Tuple(y=In(Bits[8]),x=In(Bit)),x=In(Bit)))because the key y is not wireable
>> io.O.wire(io.I)\
"""

assert has_error(caplog, msg)
Loading