Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional interleaving / deinterleaving in Axis #62

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ To use these modules, import the one you need and connect it to the DUT:

The first argument to the constructor accepts an `AxiStreamBus` object. This object is a container for the interface signals and includes class methods to automate connections.

To allow `AxiStreamSource` to interleave data set the interleave parameter a dictionary containing `tid` and / or `tdest`. THe maximum interleave depth can also be set with `max_interleave_depth`. By default is is unbound.

To send data into a design with an `AxiStreamSource`, call `send()`/`send_nowait()` or `write()`/`write_nowait()`. Accepted data types are iterables or `AxiStreamFrame` objects. Optionally, call `wait()` to wait for the transmit operation to complete. Example:

await axis_source.send(b'test data')
Expand All @@ -246,6 +248,11 @@ To receive data with an `AxiStreamSink` or `AxiStreamMonitor`, call `recv()`/`re

data = await axis_sink.recv()

To deinterleave receive data the `interleave` parameter can be set on the `AxiStreamSink` constructor. This causes calls to `read()` and `recv()` to return data sorted by `tid` ot `tdest`, returned in order of transaction completion time.

axis_sink = AxiStreamSink(AxiStreamBus.from_prefix(dut, "m_axis"), dut.clk, dut.rst, interleave="tid")
data = await axis_sink.recv()

#### Signals

* `tdata`: data, required
Expand Down
126 changes: 88 additions & 38 deletions cocotbext/axi/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
from .version import __version__
from .reset import Reset

from functools import reduce
from random import choice


class AxiStreamFrame:
def __init__(self, tdata=b'', tkeep=None, tid=None, tdest=None, tuser=None, tx_complete=None):
Expand Down Expand Up @@ -261,9 +264,10 @@ class AxiStreamBase(Reset):
_ready_init = None

def __init__(self, bus, clock, reset=None, reset_active_level=True,
byte_size=None, byte_lanes=None, *args, **kwargs):
byte_size=None, byte_lanes=None, interleave=None, *args, **kwargs):

self.bus = bus
self.interleave = interleave
self.clock = clock
self.reset = reset
self.log = logging.getLogger(f"cocotb.{bus._entity._name}.{bus._name}")
Expand All @@ -275,10 +279,15 @@ def __init__(self, bus, clock, reset=None, reset_active_level=True,

super().__init__(*args, **kwargs)

if "tid" in self.interleave and not hasattr(self.bus, "tid"):
raise ValueError("Cannot interleave with tid on a bus without tid")
if "tdest" in self.interleave and not hasattr(self.bus, "tdest"):
raise ValueError("Cannot interleave with tdest on a bus without tdest")

self.active = False
self.queue = Queue()
self.dequeue_event = Event()
self.current_frame = None
self.current_frames = {}
self.idle_event = Event()
self.idle_event.set()
self.active_event = Event()
Expand Down Expand Up @@ -425,14 +434,20 @@ class AxiStreamSource(AxiStreamBase, AxiStreamPause):
_ready_init = None

def __init__(self, bus, clock, reset=None, reset_active_level=True,
byte_size=None, byte_lanes=None, *args, **kwargs):
byte_size=None, byte_lanes=None, interleave=None, max_interleave_depth=None, *args, **kwargs):

super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, *args, **kwargs)
super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, interleave, *args, **kwargs)

self.max_interleave_depth = max_interleave_depth
self.queue_occupancy_limit_bytes = -1
self.queue_occupancy_limit_frames = -1

async def send(self, frame):
# If interleaving enabled, check provided frame has the required parameter(s)
if "tid" in self.interleave and (frame.tid is None or type(frame.tid) is list):
raise ValueError("Sending a frame with interleaving on tid requires single tid be associated with the frame")
if "dest" in self.interleave and (frame.tdest is None or type(frame.tdest) is list):
raise ValueError("Sending a frame with interleaving on tdest requires single tdest be associated with the frame")
while self.full():
self.dequeue_event.clear()
await self.dequeue_event.wait()
Expand All @@ -444,6 +459,11 @@ async def send(self, frame):
self.queue_occupancy_frames += 1

def send_nowait(self, frame):
# If interleaving enabled, check provided frame has the required parameter(s)
if "tid" in self.interleave and (frame.tid is None or type(frame.tid) is list):
raise ValueError("Sending a frame with interleaving on tid requires single tid be associated with the frame")
if "dest" in self.interleave and (frame.tdest is None or type(frame.tdest) is list):
raise ValueError("Sending a frame with interleaving on tdest requires single tdest be associated with the frame")
if self.full():
raise QueueFull()
frame = AxiStreamFrame(frame)
Expand Down Expand Up @@ -491,14 +511,19 @@ def _handle_reset(self, state):
if hasattr(self.bus, "tuser"):
self.bus.tuser.value = 0

if self.current_frame:
self.log.warning("Flushed transmit frame during reset: %s", self.current_frame)
self.current_frame.handle_tx_complete()
self.current_frame = None
for current_frame in self.current_frames.values():
self.log.warning("Flushed transmit frame during reset: %s", current_frame)
current_frame.handle_tx_complete()
self.current_frames = {}

async def _run(self):
frame = None
frame_offset = 0
# next frame hold the most recently popped frame from the Queue
# It may be held if the number of entries in frames is >= max_interleave_depth
next_frame = None
# Frames holds the in-flight frame for each of the interleaved stream
frames = {}
frame_offsets = {}

self.active = False

has_tready = hasattr(self.bus, "tready")
Expand All @@ -519,18 +544,35 @@ async def _run(self):
tvalid_sample = (not has_tvalid) or self.bus.tvalid.value

if (tready_sample and tvalid_sample) or not tvalid_sample:
if not frame and not self.queue.empty():
frame = self.queue.get_nowait()
self.dequeue_event.set()
self.queue_occupancy_bytes -= len(frame)
self.queue_occupancy_frames -= 1
self.current_frame = frame
frame.sim_time_start = get_sim_time()
frame.sim_time_end = None
self.log.info("TX frame: %s", frame)
frame.normalize()
self.active = True
frame_offset = 0

# Pop a frame from the queue if we have space
if not next_frame and not self.queue.empty():
next_frame = self.queue.get_nowait()

# Schedule the previously popped frame if that doesn't exceed our limits
if next_frame and (self.max_interleave_depth is None or len(frames) < self.max_interleave_depth):
k = (next_frame.tid if "tid" in self.interleave else None, next_frame.tdest if "tdest" in self.interleave else None)
if frames.get(k) == None:
frame = next_frame
next_frame = None
self.dequeue_event.set()
self.queue_occupancy_bytes -= len(frame)
self.queue_occupancy_frames -= 1
self.current_frames[k] = frame
frame.sim_time_start = get_sim_time()
frame.sim_time_end = None
self.log.info("TX frame: %s", frame)
frame.normalize()
self.active = True
frames[k] = frame
frame_offsets[k] = 0

frame = None
k = None
frame_offset = 0
if frames:
k, frame = choice(frames)
frame_offset = frame_offsets[k]

if frame and not self.pause:
tdata_val = 0
Expand All @@ -547,15 +589,17 @@ async def _run(self):
tdest_val = frame.tdest[frame_offset]
tuser_val = frame.tuser[frame_offset]
frame_offset += 1
frame_offsets[k] = frame_offset

if frame_offset >= len(frame.tdata):
tlast_val = 1
frame.sim_time_end = get_sim_time()
frame.handle_tx_complete()
frame = None
self.current_frame = None
del frames[k]
del self.current_frames[k]
del frame_offsets[k]
break

self.bus.tdata.value = tdata_val
if has_tvalid:
self.bus.tvalid.value = 1
Expand All @@ -574,8 +618,8 @@ async def _run(self):
self.bus.tvalid.value = 0
if has_tlast:
self.bus.tlast.value = 0
self.active = bool(frame)
if not frame and self.queue.empty():
self.active = bool(frames)
if not frames and self.empty():
self.idle_event.set()
self.active_event.clear()

Expand All @@ -592,9 +636,9 @@ class AxiStreamMonitor(AxiStreamBase):
_ready_init = None

def __init__(self, bus, clock, reset=None, reset_active_level=True,
byte_size=None, byte_lanes=None, *args, **kwargs):
byte_size=None, byte_lanes=None, interleave=None, *args, **kwargs):

super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, *args, **kwargs)
super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, interleave, *args, **kwargs)

self.read_queue = []

Expand Down Expand Up @@ -666,7 +710,7 @@ async def _run_tready_monitor(self):
self.wake_event.set()

async def _run(self):
frame = None
frames = {}
self.active = False

has_tready = hasattr(self.bus, "tready")
Expand All @@ -689,6 +733,9 @@ async def _run(self):
tvalid_sample = (not has_tvalid) or self.bus.tvalid.value

if tready_sample and tvalid_sample:
k = (self.bus.tid.value if "tid" in self.interleave else None, self.bus.tdest.value if "tdest" in self.interleave else None)
frame = frames.pop(k)

if not frame:
if self.byte_size == 8:
frame = AxiStreamFrame(bytearray(), [], [], [], [])
Expand Down Expand Up @@ -717,8 +764,8 @@ async def _run(self):

self.queue.put_nowait(frame)
self.active_event.set()

frame = None
else:
frames[k] = frame
else:
self.active = bool(frame)

Expand All @@ -736,12 +783,12 @@ class AxiStreamSink(AxiStreamMonitor, AxiStreamPause):
_ready_init = 0

def __init__(self, bus, clock, reset=None, reset_active_level=True,
byte_size=None, byte_lanes=None, *args, **kwargs):
byte_size=None, byte_lanes=None, interleave=None, *args, **kwargs):

self.queue_occupancy_limit_bytes = -1
self.queue_occupancy_limit_frames = -1

super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, *args, **kwargs)
super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, interleave *args, **kwargs)

def full(self):
if self.queue_occupancy_limit_bytes > 0 and self.queue_occupancy_bytes > self.queue_occupancy_limit_bytes:
Expand All @@ -765,7 +812,7 @@ def _dequeue(self, frame):
self.wake_event.set()

async def _run(self):
frame = None
frames = {}
self.active = False

has_tready = hasattr(self.bus, "tready")
Expand All @@ -790,6 +837,9 @@ async def _run(self):
tvalid_sample = (not has_tvalid) or self.bus.tvalid.value

if tready_sample and tvalid_sample:
k = (self.bus.tid.value if "tid" in self.interleave else None, self.bus.tdest.value if "tdest" in self.interleave else None)
frame = frames.pop(k)

if not frame:
if self.byte_size == 8:
frame = AxiStreamFrame(bytearray(), [], [], [], [])
Expand Down Expand Up @@ -818,10 +868,10 @@ async def _run(self):

self.queue.put_nowait(frame)
self.active_event.set()

frame = None
else:
frames[k] = frame
else:
self.active = bool(frame)
self.active = reduce(lambda r, f: r or bool(f), frames, False)

if has_tready:
self.bus.tready.value = (not self.full() and not pause_sample)
Expand Down
2 changes: 1 addition & 1 deletion tests/axis/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

TOPLEVEL_LANG = verilog

SIM ?= icarus
SIM ?= verilator
ollie-etl marked this conversation as resolved.
Show resolved Hide resolved
WAVES ?= 0

COCOTB_HDL_TIMEUNIT = 1ns
Expand Down
13 changes: 7 additions & 6 deletions tests/axis/test_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@


class TB:
def __init__(self, dut):
def __init__(self, dut, interleave):
self.dut = dut

self.log = logging.getLogger("cocotb.tb")
self.log.setLevel(logging.DEBUG)

cocotb.start_soon(Clock(dut.clk, 2, units="ns").start())

self.source = AxiStreamSource(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst)
self.sink = AxiStreamSink(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst)
self.monitor = AxiStreamMonitor(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst)
self.source = AxiStreamSource(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst, interleave=interleave)
self.sink = AxiStreamSink(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst, interleave=interleave)
self.monitor = AxiStreamMonitor(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst, interleave=interleave)

def set_idle_generator(self, generator=None):
if generator:
Expand All @@ -71,9 +71,9 @@ async def reset(self):
await RisingEdge(self.dut.clk)


async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=None, backpressure_inserter=None):
async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=None, backpressure_inserter=None, interleave=None):

tb = TB(dut)
tb = TB(dut,interleave)

id_count = 2**len(tb.source.bus.tid)

Expand Down Expand Up @@ -141,6 +141,7 @@ def incrementing_payload(length):
factory.add_option("payload_lengths", [size_list])
factory.add_option("payload_data", [incrementing_payload])
factory.add_option("idle_inserter", [None, cycle_pause])
factory.add_option("interleave", [None, "tid", "tdest", {"tid", "tdest"} ])
factory.add_option("backpressure_inserter", [None, cycle_pause])
factory.generate_tests()

Expand Down