From fef13b850a394348d73e45468fe429f46834f2ff Mon Sep 17 00:00:00 2001 From: mndza Date: Mon, 15 Jan 2024 16:21:32 +0100 Subject: [PATCH] gateware.usb.usb2.request: add `claim` to `RequestHandlerInterface` Introduce a new `claim` signal to the `RequestHandlerInterface`. If a `USBRequestHandler` wants to manage an incoming request, it must assert this signal to gain control of the remaining interface outputs. Additionally, this commit simplifies the logic within `USBRequestHandlerMultiplexer` and provides a fallback mechanism for unhandled requests. --- examples/usb/vendor_request.py | 14 ++--- luna/gateware/usb/devices/acm.py | 16 ++---- luna/gateware/usb/request/standard.py | 7 ++- luna/gateware/usb/usb2/control.py | 16 +----- luna/gateware/usb/usb2/request.py | 81 ++++++++++++++------------- 5 files changed, 57 insertions(+), 77 deletions(-) diff --git a/examples/usb/vendor_request.py b/examples/usb/vendor_request.py index 80db7b092..45e178ee6 100755 --- a/examples/usb/vendor_request.py +++ b/examples/usb/vendor_request.py @@ -39,6 +39,9 @@ def elaborate(self, platform): # to a user provided value with m.Case(self.REQUEST_SET_LEDS): + # Drive interface outputs for this request + m.d.comb += interface.claim.eq(1) + # If we have an active data byte, splat it onto the LEDs. # # For simplicity of this example, we'll accept any byte in @@ -56,16 +59,7 @@ def elaborate(self, platform): with m.If(interface.status_requested): m.d.comb += self.send_zlp() - - with m.Case(): - - # - # Stall unhandled requests. - # - with m.If(interface.status_requested | interface.data_requested): - m.d.comb += interface.handshakes_out.stall.eq(1) - - return m + return m diff --git a/luna/gateware/usb/devices/acm.py b/luna/gateware/usb/devices/acm.py index b0d4d7bcc..f9ad0a9c5 100644 --- a/luna/gateware/usb/devices/acm.py +++ b/luna/gateware/usb/devices/acm.py @@ -45,6 +45,9 @@ def elaborate(self, platform): # SET_LINE_CODING: The host attempts to tell us how it wants serial data # encoding. Since we output a stream, we'll ignore the actual line coding. with m.Case(self.SET_LINE_CODING): + + # Drive interface outputs for this request + m.d.comb += interface.claim.eq(1) # Always ACK the data out... with m.If(interface.rx_ready_for_response): @@ -53,17 +56,8 @@ def elaborate(self, platform): # ... and accept whatever the request was. with m.If(interface.status_requested): m.d.comb += self.send_zlp() - - - with m.Case(): - - # - # Stall unhandled requests. - # - with m.If(interface.status_requested | interface.data_requested): - m.d.comb += interface.handshakes_out.stall.eq(1) - - return m + + return m class USBSerialDevice(Elaboratable): diff --git a/luna/gateware/usb/request/standard.py b/luna/gateware/usb/request/standard.py index e46cd6845..23c608a07 100644 --- a/luna/gateware/usb/request/standard.py +++ b/luna/gateware/usb/request/standard.py @@ -90,6 +90,11 @@ def elaborate(self, platform): # Handlers. # with m.If(setup.type == USBRequestType.STANDARD): + + # Only handle setup packet if not blacklisted + blacklisted = functools.reduce(operator.__or__, (f(setup) for f in self._blacklist), Const(0)) + m.d.comb += interface.claim.eq(~blacklisted) + with m.FSM(domain="usb"): # IDLE -- not handling any active request @@ -106,8 +111,6 @@ def elaborate(self, platform): # If we've received a new setup packet, handle it. with m.If(setup.received): - # Only handle setup packet if not blacklisted - blacklisted = functools.reduce(operator.__or__, (f(setup) for f in self._blacklist), Const(0)) with m.If(~blacklisted): # Select which standard packet we're going to handler. diff --git a/luna/gateware/usb/usb2/control.py b/luna/gateware/usb/usb2/control.py index 9e45a56dd..7df6c14a7 100644 --- a/luna/gateware/usb/usb2/control.py +++ b/luna/gateware/usb/usb2/control.py @@ -15,7 +15,7 @@ from .packet import USBTokenDetector, TokenDetectorInterface from .packet import InterpacketTimerInterface, HandshakeExchangeInterface from .endpoint import EndpointInterface -from .request import USBSetupDecoder, USBRequestHandlerMultiplexer, StallOnlyRequestHandler +from .request import USBSetupDecoder, USBRequestHandlerMultiplexer from ..request.standard import StandardRequestHandler from ..stream import USBInStreamInterface, USBOutStreamInterface @@ -127,20 +127,6 @@ def elaborate(self, platform): m.d.comb += tokenizer.interface.connect(interface.tokenizer) - # - # Convenience feature: - # - # If we have -only- a standard request handler, automatically add a handler that will - # stall all other requests. - # - single_handler = (len(self._request_handlers) == 1) - if (single_handler and isinstance(self._request_handlers[0], StandardRequestHandler)): - - # Add a handler that will stall any non-standard request. - stall_condition = lambda setup : setup.type != USBRequestType.STANDARD - self.add_request_handler(StallOnlyRequestHandler(stall_condition)) - - # # Submodules # diff --git a/luna/gateware/usb/usb2/request.py b/luna/gateware/usb/usb2/request.py index 4cfd12298..ef0740fb7 100644 --- a/luna/gateware/usb/usb2/request.py +++ b/luna/gateware/usb/usb2/request.py @@ -11,6 +11,7 @@ import operator from amaranth import Signal, Module, Elaboratable, Cat +from amaranth.lib.coding import Encoder from amaranth.hdl.rec import Record, DIR_FANOUT from . import USBSpeed @@ -19,7 +20,6 @@ from .packet import InterpacketTimerInterface, HandshakeExchangeInterface from ..stream import USBInStreamInterface, USBOutStreamInterface from ..request import SetupPacket -from ...utils.bus import OneHotMultiplexer from ...test import usb_domain_test_case @@ -31,6 +31,7 @@ class RequestHandlerInterface: Components (I = input to request handler; O = output to control interface): *: setup -- Carries the most recent setup request to the handler. *: tokenizer -- Carries information about any incoming token packets. + O: claim -- Assert to drive the rest of output signals. # Control request status signals. I: data_requested -- Pulsed to indicate that a data-phase IN token has been issued, @@ -64,6 +65,7 @@ class RequestHandlerInterface: def __init__(self): self.setup = SetupPacket() self.tokenizer = TokenDetectorInterface() + self.claim = Signal() self.data_requested = Signal() self.status_requested = Signal() @@ -396,16 +398,19 @@ def __init__(self): # Internals # self._interfaces = [] + self._fallback = None def add_interface(self, interface: RequestHandlerInterface): """ Adds a RequestHandlerInterface to the multiplexer. - Arbitration is not performed; it's expected only one handler will be - driving requests at a time. + It's expected only one handler will be driving requests at a time. """ self._interfaces.append(interface) + def set_fallback_interface(self, interface: RequestHandlerInterface): + """ Sets a RequestHandlerInterface as a fallback for unhandled requests. """ + self._fallback = interface def _multiplex_signals(self, m, *, when, multiplex, sub_bus=None): """ Helper that creates a simple priority-encoder multiplexer. @@ -454,11 +459,15 @@ def elaborate(self, platform): m = Module() shared = self.shared + # If no fallback request handler is provided, stall all unhandled requests. + if self._fallback is None: + m.submodules.stall_handler = stall_handler = StallOnlyRequestHandler() + self._fallback = stall_handler.interface # # Pass through signals being routed -to- our pre-mux interfaces. # - for interface in self._interfaces: + for interface in [*self._interfaces, self._fallback]: m.d.comb += [ shared.setup .connect(interface.setup), shared.tokenizer .connect(interface.tokenizer), @@ -476,42 +485,36 @@ def elaborate(self, platform): # # Multiplex the signals being routed -from- our pre-mux interface. # - self._multiplex_signals(m, - when='address_changed', - multiplex=['address_changed', 'new_address'] - ) - self._multiplex_signals(m, - when='config_changed', - multiplex=['config_changed', 'new_config'] - ) + def _connect_interface_outputs(interface): + m.d.comb += [ + shared.tx .stream_eq(interface.tx), - # Connect up our transmit interface. - m.submodules.tx_mux = tx_mux = OneHotMultiplexer( - interface_type=USBInStreamInterface, - mux_signals=('payload',), - or_signals=('valid', 'first', 'last'), - pass_signals=('ready',) - ) - tx_mux.add_interfaces(i.tx for i in self._interfaces) - m.d.comb += self.shared.tx.stream_eq(tx_mux.output) - - # Pass through the relevant PID from our data source. - for i in self._interfaces: - with m.If(i.tx.valid): - m.d.comb += self.shared.tx_data_pid.eq(i.tx_data_pid) - - # OR together all of our handshake-generation requests. - any_ack = functools.reduce(operator.__or__, (i.handshakes_out.ack for i in self._interfaces)) - any_nak = functools.reduce(operator.__or__, (i.handshakes_out.nak for i in self._interfaces)) - any_stall = functools.reduce(operator.__or__, (i.handshakes_out.stall for i in self._interfaces)) - - m.d.comb += [ - shared.handshakes_out.ack .eq(any_ack), - shared.handshakes_out.nak .eq(any_nak), - shared.handshakes_out.stall .eq(any_stall), - ] + shared.tx_data_pid .eq(interface.tx_data_pid), + + shared.handshakes_out .eq(interface.handshakes_out), + + shared.address_changed .eq(interface.address_changed), + shared.new_address .eq(interface.new_address), + shared.config_changed .eq(interface.config_changed), + shared.new_config .eq(interface.new_config), + ] + # The encoder provides the index of the single interface that claims the + # output lines. Otherwise, it asserts the .n (invalid) line. + m.submodules.encoder = encoder = Encoder(len(self._interfaces)) + m.d.comb += encoder.i.eq(Cat(interface.claim for interface in self._interfaces)) + + # Connect the interface outputs to the interface that claims them. + with m.Switch(encoder.o): + for index, interface in enumerate(self._interfaces): + with m.Case(index): + _connect_interface_outputs(interface) + + # Use the fallback handler interface for the invalid case. + with m.If(encoder.n): + _connect_interface_outputs(self._fallback) + return m @@ -523,14 +526,14 @@ class StallOnlyRequestHandler(Elaboratable): See its record definition for signal definitions. """ - def __init__(self, stall_condition): + def __init__(self, stall_condition=None): """ Parameters: stall_condition -- A function that accepts a SetupRequest packet, and returns an Amaranth conditional indicating whether we should stall. """ - self.condition = stall_condition + self.condition = stall_condition or (lambda _: 1) # # I/O port