diff --git a/hitl/streaming.py b/hitl/streaming.py index ea63fb2d2..e56b66a07 100644 --- a/hitl/streaming.py +++ b/hitl/streaming.py @@ -17,8 +17,8 @@ async def _main(): help="The MQTT topic prefix of the target") parser.add_argument("--broker", "-b", default="mqtt", type=str, help="The MQTT broker address") - parser.add_argument("--host", default="0.0.0.0", - help="Local address to listen on") + parser.add_argument("--multicast", "-m", default="239.192.1.100", type=str, + help="The multicast address to use for streaming") parser.add_argument("--port", type=int, default=9293, help="Local port to listen on") parser.add_argument("--duration", type=float, default=10., @@ -34,11 +34,16 @@ async def _main(): logger.info("Starting stream") await conf.set( - "/stream_target", {"ip": local_ip, "port": args.port}, retain=False) + "/stream_target", { + "ip": [int(x) for x in args.multicast.split('.')], + "port": args.port + }, retain=False) try: logger.info("Testing stream reception") - _transport, stream = await StabilizerStream.open((args.host, args.port)) + _transport, stream = await StabilizerStream.open(args.multicast, + args.port, + args.broker) loss = await measure(stream, args.duration) if loss > args.max_loss: raise RuntimeError("High frame loss", loss) diff --git a/py/stabilizer/stream.py b/py/stabilizer/stream.py index 6adc6f66c..7660f86a9 100644 --- a/py/stabilizer/stream.py +++ b/py/stabilizer/stream.py @@ -87,11 +87,23 @@ class StabilizerStream(asyncio.DatagramProtocol): } @classmethod - async def open(cls, local_addr, maxsize=1): - """Open a UDP socket and start receiving frames""" + async def open(cls, multicast_addr, multicast_port, broker, maxsize=1): + """Open a UDP multicast socket and start receiving frames""" loop = asyncio.get_running_loop() - transport, protocol = await loop.create_datagram_endpoint( - lambda: cls(maxsize), local_addr=local_addr) + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20) + + # We need to specify which interface to receive broadcasts from, or Windows may choose the + # wrong one. Thus, use the broker address to figure out our local address for the interface + # of interest. + group = socket.inet_aton(multicast_addr) + iface = socket.inet_aton('.'.join([str(x) for x in get_local_ip(broker)])) + sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, group + iface) + + sock.bind(('', multicast_port)) + + transport, protocol = await loop.create_datagram_endpoint(lambda: cls(maxsize), sock=sock) # Increase the OS UDP receive buffer size to 4 MiB so that latency # spikes don't impact much. Achieving 4 MiB may require increasing # the max allowed buffer size, e.g. via @@ -155,7 +167,7 @@ async def _record(): pass logger.info("Received %g MB, %g MB/s", stat.bytes/1e6, - stat.bytes/1e6/duration) + stat.bytes/1e6/duration) sent = stat.received + stat.lost if sent: @@ -173,6 +185,8 @@ async def main(): help="Local port to listen on") parser.add_argument("--host", default="0.0.0.0", help="Local address to listen on") + parser.add_argument("--broker", default="mqtt", + help="The MQTT broker address") parser.add_argument("--maxsize", type=int, default=1, help="Frame queue size") parser.add_argument("--duration", type=float, default=1., @@ -181,7 +195,7 @@ async def main(): logging.basicConfig(level=logging.INFO) _transport, stream = await StabilizerStream.open( - (args.host, args.port), args.maxsize) + (args.host, args.port), args.broker, args.maxsize) await measure(stream, args.duration)