Skip to content

Commit

Permalink
Merge pull request #221 from ehagerty/will_set_bytes
Browse files Browse the repository at this point in the history
refactor will_set function to match the publish function and allow the msg/payload to be encoded bytes, not just str, int or float.
  • Loading branch information
FoamyGuy authored Aug 25, 2024
2 parents 5f222c2 + 9c0ecfc commit 617fff7
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 17 deletions.
62 changes: 49 additions & 13 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,40 +267,76 @@ def mqtt_msg(self, msg_size: int) -> None:
if msg_size < MQTT_MSG_MAX_SZ:
self._msg_size_lim = msg_size

# pylint: disable=too-many-branches, too-many-statements
def will_set(
self,
topic: Optional[str] = None,
payload: Optional[Union[int, float, str]] = None,
qos: int = 0,
topic: str,
msg: Union[str, int, float, bytes],
retain: bool = False,
qos: int = 0,
) -> None:
"""Sets the last will and testament properties. MUST be called before `connect()`.
:param str topic: MQTT Broker topic.
:param int|float|str payload: Last will disconnection payload.
payloads of type int & float are converted to a string.
:param str|int|float|bytes msg: Last will disconnection msg.
msgs of type int & float are converted to a string.
msgs of type byetes are left unchanged, as it is in the publish function.
:param int qos: Quality of Service level, defaults to
zero. Conventional options are ``0`` (send at most once), ``1``
(send at least once), or ``2`` (send exactly once).
.. note:: Only options ``1`` or ``0`` are QoS levels supported by this library.
:param bool retain: Specifies if the payload is to be retained when
:param bool retain: Specifies if the msg is to be retained when
it is published.
"""
self.logger.debug("Setting last will properties")
self._valid_qos(qos)
if self._is_connected:
raise MMQTTException("Last Will should only be called before connect().")
if payload is None:
payload = ""
if isinstance(payload, (int, float, str)):
payload = str(payload).encode()

# check topic/msg/qos kwargs
self._valid_topic(topic)
if "+" in topic or "#" in topic:
raise MMQTTException("Publish topic can not contain wildcards.")

if msg is None:
raise MMQTTException("Message can not be None.")
if isinstance(msg, (int, float)):
msg = str(msg).encode("ascii")
elif isinstance(msg, str):
msg = str(msg).encode("utf-8")
elif isinstance(msg, bytes):
pass
else:
raise MMQTTException("Invalid message data type.")
if len(msg) > MQTT_MSG_MAX_SZ:
raise MMQTTException(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")

self._valid_qos(qos)
assert (
0 <= qos <= 1
), "Quality of Service Level 2 is unsupported by this library."

# fixed header. [3.3.1.2], [3.3.1.3]
pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1])

# variable header = 2-byte Topic length (big endian)
pub_hdr_var = bytearray(struct.pack(">H", len(topic.encode("utf-8"))))
pub_hdr_var.extend(topic.encode("utf-8")) # Topic name

remaining_length = 2 + len(msg) + len(topic.encode("utf-8"))
if qos > 0:
# packet identifier where QoS level is 1 or 2. [3.3.2.2]
remaining_length += 2
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
pub_hdr_var.append(self._pid >> 8)
pub_hdr_var.append(self._pid & 0xFF)

self._encode_remaining_length(pub_hdr_fixed, remaining_length)

self._lw_qos = qos
self._lw_topic = topic
self._lw_msg = payload
self._lw_msg = msg
self._lw_retain = retain
self.logger.debug("Last will properties successfully set")

def add_topic_callback(self, mqtt_topic: str, callback_method) -> None:
"""Registers a callback_method for a specific MQTT topic.
Expand Down
6 changes: 2 additions & 4 deletions adafruit_minimqtt/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,9 @@ def rec(node: MQTTMatcher.Node, i: int = 0):
else:
part = lst[i]
if part in node.children:
for content in rec(node.children[part], i + 1):
yield content
yield from rec(node.children[part], i + 1)
if "+" in node.children and (normal or i > 0):
for content in rec(node.children["+"], i + 1):
yield content
yield from rec(node.children["+"], i + 1)
if "#" in node.children and (normal or i > 0):
content = node.children["#"].content
if content is not None:
Expand Down

0 comments on commit 617fff7

Please sign in to comment.