Skip to content

Commit

Permalink
sim: group signal traces according to their function.
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored and whitequark committed May 22, 2024
1 parent 89eae72 commit 51e0262
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 28 deletions.
20 changes: 14 additions & 6 deletions amaranth/sim/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,17 +212,25 @@ def write_vcd(self, vcd_file, gtkw_file=None, *, traces=(), fs_per_delta=0):
file.close()
raise ValueError("Cannot start writing waveforms after advancing simulation time")

for trace in traces:
if isinstance(trace, ValueLike):
trace_cast = Value.cast(trace)
def traverse_traces(traces):
if isinstance(traces, ValueLike):
trace_cast = Value.cast(traces)
if isinstance(trace_cast, MemoryData._Row):
continue
return
for trace_signal in trace_cast._rhs_signals():
if trace_signal.name == "":
if trace_signal is trace:
if trace_signal is traces:
raise TypeError("Cannot trace signal with private name")
else:
raise TypeError(f"Cannot trace signal with private name (within {trace!r})")
raise TypeError(f"Cannot trace signal with private name (within {traces!r})")
elif isinstance(traces, (list, tuple)):
for trace in traces:
traverse_traces(trace)
elif isinstance(traces, dict):
for trace in traces.values():
traverse_traces(trace)

traverse_traces(traces)

return self._engine.write_vcd(vcd_file=vcd_file, gtkw_file=gtkw_file,
traces=traces, fs_per_delta=fs_per_delta)
72 changes: 51 additions & 21 deletions amaranth/sim/pysim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import enum as py_enum

from ..hdl import *
from ..hdl._mem import MemoryInstance
from ..hdl._ast import SignalDict
from ..lib import data, wiring
from ._base import *
from ._async import *
from ._pyeval import eval_format, eval_value, eval_assign
Expand Down Expand Up @@ -49,7 +51,7 @@ def __init__(self, state, design, *, vcd_file, gtkw_file=None, traces=(), fs_per
self.gtkw_file = gtkw_file
self.gtkw_save = gtkw_file and vcd.gtkw.GTKWSave(self.gtkw_file)

self.traces = []
self.traces = traces

signal_names = SignalDict()
memories = {}
Expand All @@ -64,9 +66,9 @@ def __init__(self, state, design, *, vcd_file, gtkw_file=None, traces=(), fs_per

trace_names = SignalDict()
assigned_names = set()
for trace in traces:
if isinstance(trace, ValueLike):
trace = Value.cast(trace)
def traverse_traces(traces):
if isinstance(traces, ValueLike):
trace = Value.cast(traces)
if isinstance(trace, MemoryData._Row):
memory = trace._memory
if not memory in memories:
Expand All @@ -77,7 +79,6 @@ def __init__(self, state, design, *, vcd_file, gtkw_file=None, traces=(), fs_per
assert name not in assigned_names
memories[memory] = ("bench", name)
assigned_names.add(name)
self.traces.append(trace)
else:
for trace_signal in trace._rhs_signals():
if trace_signal not in signal_names:
Expand All @@ -88,19 +89,27 @@ def __init__(self, state, design, *, vcd_file, gtkw_file=None, traces=(), fs_per
assert name not in assigned_names
trace_names[trace_signal] = {("bench", name)}
assigned_names.add(name)
self.traces.append(trace_signal)
elif isinstance(trace, MemoryData):
if not trace in memories:
if trace.name not in assigned_names:
name = trace.name
elif isinstance(traces, MemoryData):
if not traces in memories:
if traces.name not in assigned_names:
name = traces.name
else:
name = f"{trace.name}${len(assigned_names)}"
name = f"{traces.name}${len(assigned_names)}"
assert name not in assigned_names
memories[trace] = ("bench", name)
memories[traces] = ("bench", name)
assigned_names.add(name)
self.traces.append(trace)
elif hasattr(traces, "signature") and isinstance(traces.signature, wiring.Signature):
for name in traces.signature.members:
traverse_traces(getattr(traces, name))
elif isinstance(traces, list) or isinstance(traces, tuple):
for trace in traces:
traverse_traces(trace)
elif isinstance(traces, dict):
for trace in traces.values():
traverse_traces(trace)
else:
raise TypeError(f"{trace!r} is not a traceable object")
raise TypeError(f"{traces!r} is not a traceable object")
traverse_traces(traces)

if self.vcd_writer is None:
return
Expand Down Expand Up @@ -277,19 +286,40 @@ def close(self, timestamp):
self.gtkw_save.dumpfile_size(self.vcd_file.tell())

self.gtkw_save.treeopen("top")
for trace in self.traces:
if isinstance(trace, Signal):
for name in self.gtkw_signal_names[trace]:

def traverse_traces(traces):
if isinstance(traces, Signal):
for name in self.gtkw_signal_names[traces]:
self.gtkw_save.trace(name)
elif isinstance(trace, MemoryData):
for row_names in self.gtkw_memory_names[trace]:
elif isinstance(traces, data.View):
with self.gtkw_save.group("view"):
trace = Value.cast(traces)
for trace_signal in trace._rhs_signals():
for name in self.gtkw_signal_names[trace_signal]:
self.gtkw_save.trace(name)
elif isinstance(traces, ValueLike):
traverse_traces(Value.cast(traces))
elif isinstance(traces, MemoryData):
for row_names in self.gtkw_memory_names[traces]:
for name in row_names:
self.gtkw_save.trace(name)
elif isinstance(trace, MemoryData._Row):
for name in self.gtkw_memory_names[trace._memory][trace._index]:
elif isinstance(traces, MemoryData._Row):
for name in self.gtkw_memory_names[traces._memory][traces._index]:
self.gtkw_save.trace(name)
elif hasattr(traces, "signature") and isinstance(traces.signature, wiring.Signature):
with self.gtkw_save.group("interface"):
for _, _, member in traces.signature.flatten(traces):
traverse_traces(member)
elif isinstance(traces, list) or isinstance(traces, tuple):
for trace in traces:
traverse_traces(trace)
elif isinstance(traces, dict):
for name, trace in traces.items():
with self.gtkw_save.group(name):
traverse_traces(trace)
else:
assert False # :nocov:
traverse_traces(self.traces)

if self.close_vcd:
self.vcd_file.close()
Expand Down
45 changes: 44 additions & 1 deletion tests/test_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from amaranth.sim import *
from amaranth.sim._pyeval import eval_format
from amaranth.lib.memory import Memory
from amaranth.lib import enum, data
from amaranth.lib import enum, data, wiring

from .utils import *
from amaranth._utils import _ignore_deprecated
Expand Down Expand Up @@ -1393,6 +1393,49 @@ def testbench():
sim.add_testbench(testbench)


class SimulatorTracesTestCase(FHDLTestCase):
def assertDef(self, traces, flat_traces):
frag = Fragment()

def process():
yield Delay(1e-6)

sim = Simulator(frag)
sim.add_testbench(process)
with sim.write_vcd("test.vcd", "test.gtkw", traces=traces):
sim.run()

def test_signal(self):
a = Signal()
self.assertDef(a, [a])

def test_list(self):
a = Signal()
self.assertDef([a], [a])

def test_tuple(self):
a = Signal()
self.assertDef((a,), [a])

def test_dict(self):
a = Signal()
self.assertDef({"a": a}, [a])

def test_struct_view(self):
a = Signal(data.StructLayout({"a": 1, "b": 3}))
self.assertDef(a, [a])

def test_interface(self):
sig = wiring.Signature({
"a": wiring.In(1),
"b": wiring.Out(3),
"c": wiring.Out(2).array(4),
"d": wiring.In(wiring.Signature({"e": wiring.In(5)}))
})
a = sig.create()
self.assertDef(a, [a])


class SimulatorRegressionTestCase(FHDLTestCase):
def test_bug_325(self):
dut = Module()
Expand Down

0 comments on commit 51e0262

Please sign in to comment.