Skip to content

Commit

Permalink
Fix mqtt3 client shutdown callback (#498)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiazhvera authored Aug 28, 2023
1 parent 8523a84 commit 74e0538
Show file tree
Hide file tree
Showing 5 changed files with 605 additions and 162 deletions.
12 changes: 10 additions & 2 deletions source/mqtt5_client.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,24 @@ struct mqtt5_client_binding {
/* Called on either failed client creation or by the client upon normal client termination */
static void s_mqtt5_client_on_terminate(void *user_data) {
struct mqtt5_client_binding *client = user_data;

PyGILState_STATE state;
if (aws_py_gilstate_ensure(&state)) {
return; /* Python has shut down. Nothing matters anymore, but don't crash */
}
if (client->client_core != NULL) {
// Make sure to release the python client object
Py_XDECREF(client->client_core);
}
aws_mem_release(aws_py_get_allocator(), client);
PyGILState_Release(state);
}

/* Called when capsule's refcount hits 0 */
static void s_mqtt5_python_client_destructor(PyObject *client_capsule) {
struct mqtt5_client_binding *client = PyCapsule_GetPointer(client_capsule, s_capsule_name_mqtt5_client);
assert(client);

Py_XDECREF(client->client_core);

if (client->native != NULL) {
/* If client is not NULL, it can be shutdown and cleaned normally */
aws_mqtt5_client_release(client->native);
Expand Down
96 changes: 48 additions & 48 deletions source/mqtt_client_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ struct mqtt_connection_binding {
* Lets us invoke callbacks on the python object without preventing the GC from cleaning it up. */
PyObject *self_proxy;

/* To not run into a segfault calling on_close with the connection being freed before the callback
* can be invoked, we need to keep the PyCapsule alive. */
PyObject *self_capsule;

PyObject *on_connect;
PyObject *on_any_publish;

Expand All @@ -62,24 +58,25 @@ struct mqtt_connection_binding {

static void s_mqtt_python_connection_finish_destruction(struct mqtt_connection_binding *py_connection) {

/* Do not call the on_stopped callback if the python object is finished/destroyed */
aws_mqtt_client_connection_set_connection_closed_handler(py_connection->native, NULL, NULL);

aws_mqtt_client_connection_release(py_connection->native);

Py_DECREF(py_connection->self_proxy);
Py_DECREF(py_connection->client);
Py_XDECREF(py_connection->on_any_publish);

aws_mem_release(aws_py_get_allocator(), py_connection);
}

static void s_mqtt_python_connection_destructor_on_disconnect(
struct aws_mqtt_client_connection *connection,
void *userdata) {
static void s_start_destroy_native(struct mqtt_connection_binding *py_connection) {
if (py_connection == NULL || py_connection->native == NULL) {
return;
}

if (connection == NULL || userdata == NULL) {
return; // The connection is dead - skip!
aws_mqtt_client_connection_release(py_connection->native);
}

static void s_mqtt_python_connection_termination(void *userdata) {

if (userdata == NULL) {
return; // The binding is dead - skip!
}

struct mqtt_connection_binding *py_connection = userdata;
Expand All @@ -93,20 +90,36 @@ static void s_mqtt_python_connection_destructor_on_disconnect(
PyGILState_Release(state);
}

static void s_mqtt_python_connection_destructor_on_disconnect(
struct aws_mqtt_client_connection *connection,
void *user_data) {
if (connection == NULL || user_data == NULL) {
return; // The connection is dead - skip!
}

struct mqtt_connection_binding *py_connection = user_data;
PyGILState_STATE state;
if (aws_py_gilstate_ensure(&state)) {
return; /* Python has shut down. Nothing matters anymore, but don't crash */
}
s_start_destroy_native(py_connection);
PyGILState_Release(state);
}

static void s_mqtt_python_connection_destructor(PyObject *connection_capsule) {

struct mqtt_connection_binding *py_connection =
PyCapsule_GetPointer(connection_capsule, s_capsule_name_mqtt_client_connection);
assert(py_connection);
AWS_FATAL_ASSERT(py_connection);
AWS_FATAL_ASSERT(py_connection->native);

/* This is the destructor from Python - so we can ignore the closed callback here */
aws_mqtt_client_connection_set_connection_closed_handler(py_connection->native, NULL, NULL);

if (aws_mqtt_client_connection_disconnect(
py_connection->native, s_mqtt_python_connection_destructor_on_disconnect, py_connection)) {

/* If this returns an error, we should immediately destroy the connection */
s_mqtt_python_connection_finish_destruction(py_connection);
/* If we already disconnected, we should immediately release the native connection */
s_start_destroy_native(py_connection);
}
}

Expand Down Expand Up @@ -254,15 +267,6 @@ static void s_on_connection_closed(
PyErr_WriteUnraisable(PyErr_Occurred());
}
}
Py_DECREF(py_connection->self_proxy);

/** Allow the PyCapsule to be freed like normal again.
* If this is the last reference (I.E customer code called disconnect and threw the Python object away)
* Then this will allow the MQTT311 class to be fully cleaned.
* If it is not the last reference (customer still has reference) then when the customer is done
* it will be freed like normal.
**/
Py_DECREF(py_connection->self_capsule);

PyGILState_Release(state);
}
Expand All @@ -272,6 +276,7 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) {

struct aws_allocator *allocator = aws_py_get_allocator();

PyObject *self_proxy;
PyObject *self_py;
PyObject *client_py;
PyObject *use_websocket_py;
Expand Down Expand Up @@ -310,13 +315,19 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) {
}
if (!py_connection->native) {
PyErr_SetAwsLastError();
goto connection_new_failed;
goto on_error;
}

if (aws_mqtt_client_connection_set_connection_termination_handler(
py_connection->native, s_mqtt_python_connection_termination, py_connection)) {
PyErr_SetAwsLastError();
goto on_error;
}

if (aws_mqtt_client_connection_set_connection_result_handlers(
py_connection->native, s_on_connection_success, py_connection, s_on_connection_failure, py_connection)) {
PyErr_SetAwsLastError();
goto set_connection_handlers_failed;
goto on_error;
}

if (aws_mqtt_client_connection_set_connection_interruption_handlers(
Expand All @@ -327,13 +338,13 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) {
py_connection)) {

PyErr_SetAwsLastError();
goto set_interruption_failed;
goto on_error;
}

if (aws_mqtt_client_connection_set_connection_closed_handler(
py_connection->native, s_on_connection_closed, py_connection)) {
PyErr_SetAwsLastError();
goto set_interruption_failed;
goto on_error;
}

if (PyObject_IsTrue(use_websocket_py)) {
Expand All @@ -345,39 +356,32 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) {
NULL /*validator userdata*/)) {

PyErr_SetAwsLastError();
goto use_websockets_failed;
goto on_error;
}
}

PyObject *self_proxy = PyWeakref_NewProxy(self_py, NULL);
self_proxy = PyWeakref_NewProxy(self_py, NULL);
if (!self_proxy) {
goto proxy_new_failed;
goto on_error;
}

PyObject *capsule =
PyCapsule_New(py_connection, s_capsule_name_mqtt_client_connection, s_mqtt_python_connection_destructor);
if (!capsule) {
goto capsule_new_failed;
goto on_error;
}

/* From hereon, nothing will fail */

py_connection->self_capsule = capsule;
py_connection->self_proxy = self_proxy;

py_connection->client = client_py;
Py_INCREF(py_connection->client);

return capsule;

capsule_new_failed:
Py_DECREF(self_proxy);
proxy_new_failed:
use_websockets_failed:
set_interruption_failed:
set_connection_handlers_failed:
on_error:
Py_XDECREF(self_proxy);
aws_mqtt_client_connection_release(py_connection->native);
connection_new_failed:
aws_mem_release(allocator, py_connection);
return NULL;
}
Expand Down Expand Up @@ -1329,14 +1333,10 @@ PyObject *aws_py_mqtt_client_connection_disconnect(PyObject *self, PyObject *arg
}

Py_INCREF(on_disconnect);
Py_INCREF(connection->self_proxy); /* We need to keep self_proxy alive for on_closed, which will dec-ref this */
Py_INCREF(connection->self_capsule); /* Do not allow the PyCapsule to be freed, we need it alive for on_closed */

int err = aws_mqtt_client_connection_disconnect(connection->native, s_on_disconnect, on_disconnect);
if (err) {
Py_DECREF(on_disconnect);
Py_DECREF(connection->self_proxy);
Py_DECREF(connection->self_capsule);
return PyErr_AwsLastError();
}

Expand Down
111 changes: 0 additions & 111 deletions test/test_mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,117 +1549,6 @@ def test_operation_statistics_uc1(self):
client.stop()
callbacks.future_stopped.result(TIMEOUT)

# ==============================================================
# 5to3 ADAPTER TEST CASES
# ==============================================================
def test_5to3Adapter_connection_creation_minimum(self):
client5 = self._create_client()
connection = client5.new_connection()

def test_5to3Adapter_connection_creation_maximum(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST")

user_properties = []
user_properties.append(mqtt5.UserProperty(name="name1", value="value1"))
user_properties.append(mqtt5.UserProperty(name="name2", value="value2"))

publish_packet = mqtt5.PublishPacket(
payload="TEST_PAYLOAD",
qos=mqtt5.QoS.AT_LEAST_ONCE,
retain=False,
topic="TEST_TOPIC",
payload_format_indicator=mqtt5.PayloadFormatIndicator.AWS_MQTT5_PFI_UTF8,
message_expiry_interval_sec=10,
topic_alias=1,
response_topic="TEST_RESPONSE_TOPIC",
correlation_data="TEST_CORRELATION_DATA",
content_type="TEST_CONTENT_TYPE",
user_properties=user_properties
)

connect_options = mqtt5.ConnectPacket(
keep_alive_interval_sec=10,
client_id="TEST_CLIENT",
username="USERNAME",
password="PASSWORD",
session_expiry_interval_sec=100,
request_response_information=1,
request_problem_information=1,
receive_maximum=1000,
maximum_packet_size=10000,
will_delay_interval_sec=1000,
will=publish_packet,
user_properties=user_properties
)
client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=8883,
connect_options=connect_options,
session_behavior=mqtt5.ClientSessionBehaviorType.CLEAN,
extended_validation_and_flow_control_options=mqtt5.ExtendedValidationAndFlowControlOptions.AWS_IOT_CORE_DEFAULTS,
offline_queue_behavior=mqtt5.ClientOperationQueueBehaviorType.FAIL_ALL_ON_DISCONNECT,
retry_jitter_mode=mqtt5.ExponentialBackoffJitterMode.DECORRELATED,
min_reconnect_delay_ms=100,
max_reconnect_delay_ms=50000,
min_connected_time_to_reset_reconnect_delay_ms=1000,
ping_timeout_ms=1000,
connack_timeout_ms=1000,
ack_timeout_sec=100)
client = self._create_client(client_options=client_options)
connection = client.new_connection()

def test_5to3Adapter_direct_connect_minimum(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT5_DIRECT_MQTT_HOST")
input_port = int(_get_env_variable("AWS_TEST_MQTT5_DIRECT_MQTT_PORT"))

client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=input_port
)
callbacks = Mqtt5TestCallbacks()
client = self._create_client(client_options=client_options, callbacks=callbacks)

connection = client.new_connection()
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_5to3Adapter_websocket_connect_minimum(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT5_WS_MQTT_HOST")
input_port = int(_get_env_variable("AWS_TEST_MQTT5_WS_MQTT_PORT"))

client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=input_port
)
callbacks = Mqtt5TestCallbacks()
client_options.websocket_handshake_transform = callbacks.ws_handshake_transform

client = self._create_client(client_options=client_options, callbacks=callbacks)
connection = client.new_connection()
connection.connect().result(TIMEOUT)
callbacks.future_connection_success.result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_5to3Adapter_direct_connect_mutual_tls(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST")
input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT")
input_key = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_KEY")

client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=8883
)
tls_ctx_options = io.TlsContextOptions.create_client_with_mtls_from_path(
input_cert,
input_key
)
client_options.tls_ctx = io.ClientTlsContext(tls_ctx_options)
callbacks = Mqtt5TestCallbacks()
client = self._create_client(client_options=client_options, callbacks=callbacks)
connection = client.new_connection()
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)


if __name__ == 'main':
unittest.main()
Loading

0 comments on commit 74e0538

Please sign in to comment.