From ecf64e7a56ee85e10a812139a4aee09e736aa241 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Nov 2020 22:36:21 +0100 Subject: [PATCH] Handle non-contiguous memoryviews in C extension. This avoids the special-case in Python code. --- src/websockets/frames.py | 11 ++------ src/websockets/speedups.c | 51 ++++++++++++++++++----------------- tests/legacy/test_protocol.py | 30 --------------------- tests/test_frames.py | 9 ------- tests/test_utils.py | 24 +++-------------- 5 files changed, 32 insertions(+), 93 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 74223c0e8..71783e176 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -263,13 +263,8 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: """ if isinstance(data, str): return OP_TEXT, data.encode("utf-8") - elif isinstance(data, (bytes, bytearray)): + elif isinstance(data, (bytes, bytearray, memoryview)): return OP_BINARY, data - elif isinstance(data, memoryview): - if data.c_contiguous: - return OP_BINARY, data - else: - return OP_BINARY, data.tobytes() else: raise TypeError("data must be bytes-like or str") @@ -290,10 +285,8 @@ def prepare_ctrl(data: Data) -> bytes: """ if isinstance(data, str): return data.encode("utf-8") - elif isinstance(data, (bytes, bytearray)): + elif isinstance(data, (bytes, bytearray, memoryview)): return bytes(data) - elif isinstance(data, memoryview): - return data.tobytes() else: raise TypeError("data must be bytes-like or str") diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index ede181e5d..fc328e528 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -13,39 +13,35 @@ static const Py_ssize_t MASK_LEN = 4; /* Similar to PyBytes_AsStringAndSize, but accepts more types */ static int -_PyBytesLike_AsStringAndSize(PyObject *obj, char **buffer, Py_ssize_t *length) +_PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length) { - // This supports bytes, bytearrays, and C-contiguous memoryview objects, - // which are the most useful data structures for handling byte streams. - // websockets.framing.prepare_data() returns only values of these types. - // Any object implementing the buffer protocol could be supported, however - // that would require allocation or copying memory, which is expensive. + // This supports bytes, bytearrays, and memoryview objects, + // which are common data structures for handling byte streams. + // websockets.framing.prepare_data() returns only these types. + // If *tmp isn't NULL, the caller gets a new reference. if (PyBytes_Check(obj)) { + *tmp = NULL; *buffer = PyBytes_AS_STRING(obj); *length = PyBytes_GET_SIZE(obj); } else if (PyByteArray_Check(obj)) { + *tmp = NULL; *buffer = PyByteArray_AS_STRING(obj); *length = PyByteArray_GET_SIZE(obj); } else if (PyMemoryView_Check(obj)) { - Py_buffer *mv_buf; - mv_buf = PyMemoryView_GET_BUFFER(obj); - if (PyBuffer_IsContiguous(mv_buf, 'C')) - { - *buffer = mv_buf->buf; - *length = mv_buf->len; - } - else + *tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C'); + if (*tmp == NULL) { - PyErr_Format( - PyExc_TypeError, - "expected a contiguous memoryview"); return -1; } + Py_buffer *mv_buf; + mv_buf = PyMemoryView_GET_BUFFER(*tmp); + *buffer = mv_buf->buf; + *length = mv_buf->len; } else { @@ -74,15 +70,17 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) // A pointer to a char * + length will be extracted from the data and mask // arguments, possibly via a Py_buffer. + PyObject *input_tmp = NULL; char *input; Py_ssize_t input_len; + PyObject *mask_tmp = NULL; char *mask; Py_ssize_t mask_len; // Initialize a PyBytesObject then get a pointer to the underlying char * // in order to avoid an extra memory copy in PyBytes_FromStringAndSize. - PyObject *result; + PyObject *result = NULL; char *output; // Other variables. @@ -94,23 +92,23 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) if (!PyArg_ParseTupleAndKeywords( args, kwds, "OO", kwlist, &input_obj, &mask_obj)) { - return NULL; + goto exit; } - if (_PyBytesLike_AsStringAndSize(input_obj, &input, &input_len) == -1) + if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1) { - return NULL; + goto exit; } - if (_PyBytesLike_AsStringAndSize(mask_obj, &mask, &mask_len) == -1) + if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1) { - return NULL; + goto exit; } if (mask_len != MASK_LEN) { PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes"); - return NULL; + goto exit; } // Create output. @@ -118,7 +116,7 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) result = PyBytes_FromStringAndSize(NULL, input_len); if (result == NULL) { - return NULL; + goto exit; } // Since we juste created result, we don't need error checks. @@ -172,6 +170,9 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) output[i] = input[i] ^ mask[i & (MASK_LEN - 1)]; } +exit: + Py_XDECREF(input_tmp); + Py_XDECREF(mask_tmp); return result; } diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 218d05376..a89bcc88b 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -580,10 +580,6 @@ def test_send_binary_from_memoryview(self): self.loop.run_until_complete(self.protocol.send(memoryview(b"tea"))) self.assertOneFrameSent(True, OP_BINARY, b"tea") - def test_send_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.send(memoryview(b"tteeaa")[::2])) - self.assertOneFrameSent(True, OP_BINARY, b"tea") - def test_send_dict(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.send({"not": "encoded"})) @@ -624,14 +620,6 @@ def test_send_iterable_binary_from_memoryview(self): (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") ) - def test_send_iterable_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete( - self.protocol.send([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - def test_send_empty_iterable(self): self.loop.run_until_complete(self.protocol.send([])) self.assertNoFrameSent() @@ -697,16 +685,6 @@ def test_send_async_iterable_binary_from_memoryview(self): (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") ) - def test_send_async_iterable_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete( - self.protocol.send( - async_iterable([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) - ) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - def test_send_empty_async_iterable(self): self.loop.run_until_complete(self.protocol.send(async_iterable([]))) self.assertNoFrameSent() @@ -799,10 +777,6 @@ def test_ping_binary_from_memoryview(self): self.loop.run_until_complete(self.protocol.ping(memoryview(b"tea"))) self.assertOneFrameSent(True, OP_PING, b"tea") - def test_ping_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.ping(memoryview(b"tteeaa")[::2])) - self.assertOneFrameSent(True, OP_PING, b"tea") - def test_ping_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.ping(42)) @@ -856,10 +830,6 @@ def test_pong_binary_from_memoryview(self): self.loop.run_until_complete(self.protocol.pong(memoryview(b"tea"))) self.assertOneFrameSent(True, OP_PONG, b"tea") - def test_pong_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.pong(memoryview(b"tteeaa")[::2])) - self.assertOneFrameSent(True, OP_PONG, b"tea") - def test_pong_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.pong(42)) diff --git a/tests/test_frames.py b/tests/test_frames.py index 4d10c6ef2..13a712322 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -218,12 +218,6 @@ def test_prepare_data_memoryview(self): (OP_BINARY, memoryview(b"tea")), ) - def test_prepare_data_non_contiguous_memoryview(self): - self.assertEqual( - prepare_data(memoryview(b"tteeaa")[::2]), - (OP_BINARY, b"tea"), - ) - def test_prepare_data_list(self): with self.assertRaises(TypeError): prepare_data([]) @@ -246,9 +240,6 @@ def test_prepare_ctrl_bytearray(self): def test_prepare_ctrl_memoryview(self): self.assertEqual(prepare_ctrl(memoryview(b"tea")), b"tea") - def test_prepare_ctrl_non_contiguous_memoryview(self): - self.assertEqual(prepare_ctrl(memoryview(b"tteeaa")[::2]), b"tea") - def test_prepare_ctrl_list(self): with self.assertRaises(TypeError): prepare_ctrl([]) diff --git a/tests/test_utils.py b/tests/test_utils.py index b490c2409..a9ea8dcbd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -43,21 +43,18 @@ def test_apply_mask(self): self.assertEqual(result, data_out) def test_apply_mask_memoryview(self): - for data_type, mask_type in self.apply_mask_type_combos: + for mask_type in [bytes, bytearray]: for data_in, mask, data_out in self.apply_mask_test_values: - data_in, mask = data_type(data_in), mask_type(mask) - data_in, mask = memoryview(data_in), memoryview(mask) + data_in, mask = memoryview(data_in), mask_type(mask) with self.subTest(data_in=data_in, mask=mask): result = self.apply_mask(data_in, mask) self.assertEqual(result, data_out) def test_apply_mask_non_contiguous_memoryview(self): - for data_type, mask_type in self.apply_mask_type_combos: + for mask_type in [bytes, bytearray]: for data_in, mask, data_out in self.apply_mask_test_values: - data_in, mask = data_type(data_in), mask_type(mask) - data_in, mask = memoryview(data_in), memoryview(mask) - data_in, mask = data_in[::-1], mask[::-1] + data_in, mask = memoryview(data_in)[::-1], mask_type(mask)[::-1] data_out = data_out[::-1] with self.subTest(data_in=data_in, mask=mask): @@ -92,16 +89,3 @@ class SpeedupsTests(ApplyMaskTests): @staticmethod def apply_mask(*args, **kwargs): return c_apply_mask(*args, **kwargs) - - def test_apply_mask_non_contiguous_memoryview(self): - for data_type, mask_type in self.apply_mask_type_combos: - for data_in, mask, data_out in self.apply_mask_test_values: - data_in, mask = data_type(data_in), mask_type(mask) - data_in, mask = memoryview(data_in), memoryview(mask) - data_in, mask = data_in[::-1], mask[::-1] - data_out = data_out[::-1] - - with self.subTest(data_in=data_in, mask=mask): - # The C extension only supports contiguous memoryviews. - with self.assertRaises(TypeError): - self.apply_mask(data_in, mask)