From 58393a17e9fefee0f7df3c141f3b3f9bb751cfcf Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 14 Sep 2023 16:22:20 -0700 Subject: [PATCH 1/7] [Tuple] Add recursive wiring key ordering check Fixes https://github.com/phanrahan/magma/issues/1316 Ensures that recursive types check tuple key ordering --- magma/tuple.py | 16 +++++++++++++-- magma/wire.py | 5 ++++- tests/test_type/test_tuple.py | 38 +++++++++++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/magma/tuple.py b/magma/tuple.py index 2e15c4632..ac7be7126 100644 --- a/magma/tuple.py +++ b/magma/tuple.py @@ -330,6 +330,18 @@ def wire(self, o, debug_info): debug_info=debug_info ) return + + for k in self.keys(): + if not type(self[k]).is_wireable(type(o[k])): + _logger.error( + WiringLog(f"Cannot wire {{}} (type={type(o)}, to " + f" {{}} (type={type(self)})" + f"because the key {k} is not wireable", + o, self), + debug_info=debug_info + ) + return + if self._should_wire_children(o): for self_elem, o_elem in zip(self, o): self_elem = magma_value(self_elem) @@ -585,7 +597,7 @@ def is_wireable(cls, rhs): rhs = magma_type(rhs) if ( not isinstance(rhs, AnonProductKind) or - len(cls.fields) != len(rhs.fields) + list(cls.fields) != list(rhs.fields) ): return False for k, v in cls.field_dict.items(): @@ -597,7 +609,7 @@ def is_bindable(cls, rhs): rhs = magma_type(rhs) if ( not isinstance(rhs, AnonProductKind) or - len(cls.fields) != len(rhs.fields) + list(cls.fields) != list(rhs.fields) ): return False for k, v in cls.field_dict.items(): diff --git a/magma/wire.py b/magma/wire.py index 3faccad93..6f69e696b 100644 --- a/magma/wire.py +++ b/magma/wire.py @@ -54,8 +54,11 @@ def wire(o, i, debug_info=None): i_T, o_T = type(i), type(o) if not i_T.is_wireable(o_T): + # Escape curly braces in type string for format call. + o_T_str = str(o_T).replace("{", "{{").replace("}", "}}") + i_T_str = str(i_T).replace("{", "{{").replace("}", "}}") _logger.error( - WiringLog(f"Cannot wire {{}} ({o_T}) to {{}} ({i_T})", + WiringLog(f"Cannot wire {{}} ({o_T_str}) to {{}} ({i_T_str})", o, i), debug_info=debug_info ) diff --git a/tests/test_type/test_tuple.py b/tests/test_type/test_tuple.py index 45ea260f5..a41584c8d 100644 --- a/tests/test_type/test_tuple.py +++ b/tests/test_type/test_tuple.py @@ -366,11 +366,45 @@ def test_tuple_key_ordering(caplog): class tuple_key_ordering(m.Circuit): io = m.IO(I=m.In(T0), O=m.Out(T1)) + io.O.wire(io.I) + + msg = """\ +\033[1mtests/test_type/test_tuple.py:369\033[0m: Cannot wire tuple_key_ordering.I (type=Tuple(x=Out(Bit),y=Out(Bits[8])), keys=['y', 'x']) to tuple_key_ordering.O (type=Tuple(y=In(Bits[8]),x=In(Bit)), keys=['x', 'y']) because the tuples do not have the same keys +>> io.O.wire(io.I)\ +""" + + assert has_error(caplog, msg) + + +def test_tuple_key_ordering_recursive(caplog): + T0 = m.AnonProduct[{"x": m.Bit, "y": m.Bits[8]}] + T1 = m.AnonProduct[{"y": m.Bits[8], "x": m.Bit}] + + class tuple_key_ordering(m.Circuit): + io = m.IO(I=m.In(m.Array[2, T0]), O=m.Out(m.Array[2, T1])) io.O @= io.I msg = """\ -\033[1mtests/test_type/test_tuple.py:368\033[0m: Cannot wire tuple_key_ordering.I (type=Tuple(x=Out(Bit),y=Out(Bits[8])), keys=['y', 'x']) to tuple_key_ordering.O (type=Tuple(y=In(Bits[8]),x=In(Bit)), keys=['x', 'y']) because the tuples do not have the same keys +\033[1mtests/test_type/test_tuple.py:385\033[0m: Cannot wire tuple_key_ordering.I (Array[(2, AnonProduct[{'x': Bit[Out], 'y': Bits[(8, Out(Bit))]}])]) to tuple_key_ordering.O (Array[(2, AnonProduct[{'y': Bits[(8, In(Bit))], 'x': Bit[In]}])]) >> io.O @= io.I\ """ - assert has_error(caplog, "") + assert has_error(caplog, msg) + + +def test_tuple_key_ordering_recursive_2(caplog): + T0 = m.AnonProduct[{"x": m.Bit, "y": m.Bits[8]}] + T1 = m.AnonProduct[{"y": m.Bits[8], "x": m.Bit}] + T2 = m.AnonProduct[{"y": T0, "x": m.Bit}] + T3 = m.AnonProduct[{"y": T1, "x": m.Bit}] + + class tuple_key_ordering(m.Circuit): + io = m.IO(I=m.In(T2), O=m.Out(T3)) + io.O.wire(io.I) + + msg = """\ +\033[1mtests/test_type/test_tuple.py:403\033[0m: Cannot wire tuple_key_ordering.I (type=Tuple(y=Tuple(x=Out(Bit),y=Out(Bits[8])),x=Out(Bit)), to tuple_key_ordering.O (type=Tuple(y=Tuple(y=In(Bits[8]),x=In(Bit)),x=In(Bit)))because the key y is not wireable +>> io.O.wire(io.I)\ +""" + + assert has_error(caplog, msg) From 1256d952fabaf6f06ec6754fb81b88fe4b21c491 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 14 Sep 2023 16:36:14 -0700 Subject: [PATCH 2/7] Fix field logic bug --- magma/tuple.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/magma/tuple.py b/magma/tuple.py index ac7be7126..0fbd67b49 100644 --- a/magma/tuple.py +++ b/magma/tuple.py @@ -163,7 +163,8 @@ def is_wireable(cls, rhs): if not isinstance(rhs, TupleKind) or len(cls.fields) != len(rhs.fields): return False for idx, T in enumerate(cls.fields): - if not T.is_wireable(rhs[idx]): + T = magma_type(T) + if not T.is_wireable(magma_type(rhs[idx])): return False return True @@ -172,7 +173,8 @@ def is_bindable(cls, rhs): if not isinstance(rhs, TupleKind) or len(cls.fields) != len(rhs.fields): return False for idx, T in enumerate(cls.fields): - if not T.is_bindable(rhs[idx]): + T = magma_type(T) + if not T.is_bindable(magma_type(rhs[idx])): return False return True @@ -597,11 +599,12 @@ def is_wireable(cls, rhs): rhs = magma_type(rhs) if ( not isinstance(rhs, AnonProductKind) or - list(cls.fields) != list(rhs.fields) + list(cls.field_dict.keys()) != list(rhs.field_dict.keys()) ): return False for k, v in cls.field_dict.items(): - if k not in rhs.field_dict or not v.is_wireable(rhs.field_dict[k]): + v = magma_type(v) + if not v.is_wireable(magma_type(rhs.field_dict[k])): return False return True @@ -609,11 +612,12 @@ def is_bindable(cls, rhs): rhs = magma_type(rhs) if ( not isinstance(rhs, AnonProductKind) or - list(cls.fields) != list(rhs.fields) + list(cls.field_dict.keys()) != list(rhs.field_dict.keys()) ): return False for k, v in cls.field_dict.items(): - if k not in rhs.field_dict or not v.is_bindable(rhs.field_dict[k]): + v = magma_type(v) + if not v.is_bindable(magma_type(rhs.field_dict[k])): return False return True From f1e96db0157e652f10e2b32dc75788a51b69962c Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 14 Sep 2023 16:49:17 -0700 Subject: [PATCH 3/7] Fix elab issue --- magma/tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/magma/tuple.py b/magma/tuple.py index 0fbd67b49..08defd78a 100644 --- a/magma/tuple.py +++ b/magma/tuple.py @@ -334,7 +334,7 @@ def wire(self, o, debug_info): return for k in self.keys(): - if not type(self[k]).is_wireable(type(o[k])): + if not type(self)[k].is_wireable(type(o)[k]): _logger.error( WiringLog(f"Cannot wire {{}} (type={type(o)}, to " f" {{}} (type={type(self)})" From f2d338840f026954248306a86efaa465cd2d4eda Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 14 Sep 2023 16:56:24 -0700 Subject: [PATCH 4/7] Fix index logic --- magma/tuple.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/magma/tuple.py b/magma/tuple.py index 08defd78a..c70ba409a 100644 --- a/magma/tuple.py +++ b/magma/tuple.py @@ -333,8 +333,8 @@ def wire(self, o, debug_info): ) return - for k in self.keys(): - if not type(self)[k].is_wireable(type(o)[k]): + for i, k in enumerate(self.keys()): + if not type(self)[i].is_wireable(type(o)[i]): _logger.error( WiringLog(f"Cannot wire {{}} (type={type(o)}, to " f" {{}} (type={type(self)})" From bb93915afdb23a2c47adce7f705ca9b9ab5bb992 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 14 Sep 2023 16:58:59 -0700 Subject: [PATCH 5/7] Use .fields --- magma/tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/magma/tuple.py b/magma/tuple.py index c70ba409a..0c21176c8 100644 --- a/magma/tuple.py +++ b/magma/tuple.py @@ -334,7 +334,7 @@ def wire(self, o, debug_info): return for i, k in enumerate(self.keys()): - if not type(self)[i].is_wireable(type(o)[i]): + if not type(self).fields[i].is_wireable(type(o).fields[i]): _logger.error( WiringLog(f"Cannot wire {{}} (type={type(o)}, to " f" {{}} (type={type(self)})" From 6227eeefcb86c43a8e8ecbf48d64fc77ab864a4c Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 21 Sep 2023 12:56:34 -0700 Subject: [PATCH 6/7] Address review comments --- magma/tuple.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/magma/tuple.py b/magma/tuple.py index 0c21176c8..8a36f66cb 100644 --- a/magma/tuple.py +++ b/magma/tuple.py @@ -163,8 +163,7 @@ def is_wireable(cls, rhs): if not isinstance(rhs, TupleKind) or len(cls.fields) != len(rhs.fields): return False for idx, T in enumerate(cls.fields): - T = magma_type(T) - if not T.is_wireable(magma_type(rhs[idx])): + if not magma_type(T).is_wireable(rhs[idx]): return False return True @@ -336,10 +335,13 @@ def wire(self, o, debug_info): for i, k in enumerate(self.keys()): if not type(self).fields[i].is_wireable(type(o).fields[i]): _logger.error( - WiringLog(f"Cannot wire {{}} (type={type(o)}, to " - f" {{}} (type={type(self)})" - f"because the key {k} is not wireable", - o, self), + WiringLog( + f"Cannot wire {{}} (type={type(o)}, to " + f" {{}} (type={type(self)})" + f"because the key {k} is not wireable", + o, + self + ), debug_info=debug_info ) return @@ -603,8 +605,7 @@ def is_wireable(cls, rhs): ): return False for k, v in cls.field_dict.items(): - v = magma_type(v) - if not v.is_wireable(magma_type(rhs.field_dict[k])): + if not magma_type(v).is_wireable(rhs.field_dict[k]): return False return True From 0600fd7a3e0fd8c157993040aca279a97895ffcf Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 21 Sep 2023 14:55:06 -0700 Subject: [PATCH 7/7] Update is_bindable --- magma/tuple.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/magma/tuple.py b/magma/tuple.py index 8a36f66cb..7f901f90b 100644 --- a/magma/tuple.py +++ b/magma/tuple.py @@ -172,8 +172,7 @@ def is_bindable(cls, rhs): if not isinstance(rhs, TupleKind) or len(cls.fields) != len(rhs.fields): return False for idx, T in enumerate(cls.fields): - T = magma_type(T) - if not T.is_bindable(magma_type(rhs[idx])): + if not magma_type(T).is_bindable(rhs[idx]): return False return True @@ -617,8 +616,7 @@ def is_bindable(cls, rhs): ): return False for k, v in cls.field_dict.items(): - v = magma_type(v) - if not v.is_bindable(magma_type(rhs.field_dict[k])): + if not magma_type(v).is_bindable(rhs.field_dict[k]): return False return True