From d22c6d0d3331b6a62d47757a7f6fee6f3773a344 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Sun, 24 Nov 2024 21:21:04 +0000 Subject: [PATCH 1/5] adding unittest for base connection object --- surrealdb/connection.py | 132 +++++++++++++++++++++++++++---- surrealdb/connection_ws.py | 15 +++- surrealdb/data/README.md | 43 ++++++++++ tests/unit/test_connection.py | 92 +++++++++++++++++++++ tests/unit/test_ws_connection.py | 11 +++ 5 files changed, 275 insertions(+), 18 deletions(-) create mode 100644 surrealdb/data/README.md create mode 100644 tests/unit/test_connection.py diff --git a/surrealdb/connection.py b/surrealdb/connection.py index 16ee2c36..b5a98abe 100644 --- a/surrealdb/connection.py +++ b/surrealdb/connection.py @@ -1,17 +1,25 @@ +import logging import secrets import string -import logging import threading import uuid +from asyncio import Queue from dataclasses import dataclass - from typing import Dict, Tuple + from surrealdb.constants import REQUEST_ID_LENGTH from surrealdb.data.cbor import encode, decode -from asyncio import Queue class ResponseType: + """ + Enum-like class representing response types for the connection. + + Attributes: + SEND (int): Response type for standard requests. + NOTIFICATION (int): Response type for notifications. + ERROR (int): Response type for errors. + """ SEND = 1 NOTIFICATION = 2 ERROR = 3 @@ -19,12 +27,32 @@ class ResponseType: @dataclass class RequestData: + """ + Represents the data for a request sent over the connection. + + Attributes: + id (str): Unique identifier for the request. + method (str): The method name to invoke. + params (Tuple): Parameters for the method. + """ id: str method: str params: Tuple class Connection: + """ + Base class for managing a connection to the database. + + Manages request/response lifecycle, including the use of queues for + handling asynchronous communication. + + Attributes: + _queues (Dict[int, dict]): Mapping of response types to their queues. + _namespace (str | None): Current namespace in use. + _database (str | None): Current database in use. + _auth_token (str | None): Authentication token. + """ _queues: Dict[int, dict] _namespace: str | None _database: str | None @@ -35,6 +63,13 @@ def __init__( base_url: str, logger: logging.Logger, ): + """ + Initialize the Connection instance. + + Args: + base_url (str): The base URL of the server. + logger (logging.Logger): Logger for debugging and tracking activities. + """ self._locks = { ResponseType.SEND: threading.Lock(), ResponseType.NOTIFICATION: threading.Lock(), @@ -50,27 +85,79 @@ def __init__( self._logger = logger async def use(self, namespace: str, database: str) -> None: - pass + """ + Set the namespace and database for subsequent operations. + + Args: + namespace (str): The namespace to use. + database (str): The database to use. + """ + raise NotImplementedError("use method must be implemented") async def connect(self) -> None: - pass + """ + Establish a connection to the server. + """ + raise NotImplementedError("connect method must be implemented") async def close(self) -> None: - pass - - async def _make_request(self, request_data: RequestData, encoder, decoder): - pass - - async def set(self, key: str, value): - pass - - async def unset(self, key: str): - pass + """ + Close the connection to the server. + """ + raise NotImplementedError("close method must be implemented") + + async def _make_request(self, request_data: RequestData, encoder, decoder) -> dict: + """ + Internal method to send a request and handle the response. + + Args: + request_data (RequestData): The data to send. + encoder (function): Function to encode the request. + decoder (function): Function to decode the response. + return: + dict: The response data from the request. + """ + raise NotImplementedError("_make_request method must be implemented") + + async def set(self, key: str, value) -> None: + """ + Set a key-value pair in the database. + + Args: + key (str): The key to set. + value: The value to set. + """ + raise NotImplementedError("set method must be implemented") + + async def unset(self, key: str) -> None: + """ + Unset a key-value pair in the database. + + Args: + key (str): The key to unset. + """ + raise NotImplementedError("unset method must be implemented") def set_token(self, token: str | None = None) -> None: + """ + Set the authentication token for the connection. + + Args: + token (str): The authentication token to be set + """ self._auth_token = token - def create_response_queue(self, response_type: int, queue_id: str): + def create_response_queue(self, response_type: int, queue_id: str) -> Queue: + """ + Create a response queue for a given response type. + + Args: + response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR). + queue_id (str): The unique identifier for the queue. + Returns: + Queue: The response queue for the given response type and queue ID + (existing queues will be overwritten if same ID is used, cannot get existing queue). + """ lock = self._locks[response_type] with lock: response_type_queues = self._queues.get(response_type) @@ -83,7 +170,18 @@ def create_response_queue(self, response_type: int, queue_id: str): self._queues[response_type] = response_type_queues return queue - def get_response_queue(self, response_type: int, queue_id: str): + def get_response_queue(self, response_type: int, queue_id: str) -> Queue | None: + """ + Get a response queue for a given response type. + + Args: + response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR). + queue_id (str): The unique identifier for the queue. + + Returns: + Queue: The response queue for the given response type and queue ID + (existing queues will be overwritten if same ID is used). + """ lock = self._locks[response_type] with lock: response_type_queues = self._queues.get(response_type) diff --git a/surrealdb/connection_ws.py b/surrealdb/connection_ws.py index 732e6d6c..ce87dc4a 100644 --- a/surrealdb/connection_ws.py +++ b/surrealdb/connection_ws.py @@ -36,10 +36,23 @@ async def use(self, namespace: str, database: str) -> None: await self.send("use", namespace, database) - async def set(self, key: str, value): + async def set(self, key: str, value) -> None: + """ + Set a key-value pair in the database. + + Args: + key (str): The key to set. + value: The value to set. + """ await self.send("let", key, value) async def unset(self, key: str): + """ + Unset a key-value pair in the database. + + Args: + key (str): The key to unset. + """ await self.send("unset", key) async def close(self): diff --git a/surrealdb/data/README.md b/surrealdb/data/README.md new file mode 100644 index 00000000..e6451294 --- /dev/null +++ b/surrealdb/data/README.md @@ -0,0 +1,43 @@ +## What is CBOR? +CBOR is a binary data serialization format similar to JSON but more compact, efficient, and capable of encoding a +broader range of data types. It is useful for exchanging structured data between systems, especially when performance +and size are critical. + +## Purpose of the CBOR Implementation + +The CBOR code here allows the custom SurrealDB types (e.g., `GeometryPoint`, `Table`, `Range`, etc.) to be serialized +into CBOR binary format and deserialized back into Python objects. This is necessary because these types are not natively +supported by CBOR; thus, custom encoding and decoding logic is implemented. + +## Key Components + +### Custom Types + +`Range` Class: Represents a range with a beginning (`begin`) and end (`end`). These can either be included (`BoundIncluded`) or excluded (`BoundExcluded`). +`Table`, `RecordID`, `GeometryPoint`, etc.: Custom SurrealDB-specific data types, representing domain-specific constructs like tables, records, and geometrical objects. + +### CBOR Encoder + +The function `default_encoder` is used to encode custom Python objects into CBOR's binary format. This is done by associating a specific CBOR tag (a numeric identifier) with each data type. + +For example: + +`GeometryPoint` objects are encoded using the tag `TAG_GEOMETRY_POINT` with its coordinates as the value. +`Range` objects are encoded using the tag `TAG_BOUND_EXCLUDED` with a list [begin, end] as its value. +The `CBORTag` class is used to represent tagged data in `CBOR`. + +### CBOR Decoder + +The function `tag_decoder` is the inverse of `default_encoder`. It takes tagged CBOR data and reconstructs the corresponding Python objects. + +For example: + +When encountering the `TAG_GEOMETRY_POINT` tag, it creates a `GeometryPoint` object using the tag's value (coordinates). +When encountering the `TAG_RANGE` tag, it creates a `Range` object using the tag's value (begin and end). + +### encode and decode Functions + +These are high-level functions for serializing and deserializing data: + +`encode(obj)`: Converts a Python object into CBOR binary format. +`decode(data)`: Converts CBOR binary data back into a Python object using the custom decoding logic. \ No newline at end of file diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py new file mode 100644 index 00000000..60cc1620 --- /dev/null +++ b/tests/unit/test_connection.py @@ -0,0 +1,92 @@ +import asyncio +import logging + +from unittest import IsolatedAsyncioTestCase, main +from unittest.mock import patch +import threading + +from surrealdb.connection import Connection, ResponseType, RequestData + + +class TestConnection(IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.logger = logging.getLogger(__name__) + self.url: str = 'http://localhost:8000' + self.con = Connection(base_url=self.url, logger=self.logger) + + async def test___init__(self): + self.assertEqual(self.url, self.con._base_url) + self.assertEqual(self.logger, self.con._logger) + + # assert that the locks are of type threading.Lock + self.assertEqual(type(threading.Lock()), self.con._locks[ResponseType.SEND].__class__) + self.assertEqual(type(threading.Lock()), self.con._locks[ResponseType.NOTIFICATION].__class__) + self.assertEqual(type(threading.Lock()), self.con._locks[ResponseType.ERROR].__class__) + + # assert that the queues are of type dict + self.assertEqual(dict(), self.con._queues[ResponseType.SEND]) + self.assertEqual(dict(), self.con._queues[ResponseType.NOTIFICATION]) + self.assertEqual(dict(), self.con._queues[ResponseType.ERROR]) + + async def test_use(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.use("test", "test") + message = str(context.exception) + self.assertEqual("use method must be implemented", message) + + async def test_connect(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.connect() + message = str(context.exception) + self.assertEqual("connect method must be implemented", message) + + async def test_close(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.close() + message = str(context.exception) + self.assertEqual("close method must be implemented", message) + + async def test__make_request(self): + request_data = RequestData(id="1", method="test", params=()) + with self.assertRaises(NotImplementedError) as context: + await self.con._make_request(request_data, encoder=lambda x: x, decoder=lambda x: x) + message = str(context.exception) + self.assertEqual("_make_request method must be implemented", message) + + async def test_set(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.set("test", "test") + message = str(context.exception) + self.assertEqual("set method must be implemented", message) + + async def test_unset(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.unset("test") + message = str(context.exception) + self.assertEqual("unset method must be implemented", message) + + async def test_set_token(self): + self.con.set_token("test") + self.assertEqual("test", self.con._auth_token) + + async def test_create_response_queue(self): + # get a queue when there are now queues in the dictionary + outcome = self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.assertEqual(self.con._queues[1]["test"], outcome) + self.assertEqual(id(outcome), id(self.con._queues[1]["test"])) + + # get a queue when there are queues in the dictionary with the same queue_id + outcome_two = self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.assertNotEqual(outcome, outcome_two) + self.assertNotEqual(id(outcome), id(outcome_two)) + + # get a queue when there are queues in the dictionary with different queue_id + outcome_three = self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test_two") + self.assertEqual(self.con._queues[1]["test_two"], outcome_three) + self.assertEqual(id(outcome_three), id(self.con._queues[1]["test_two"])) + + + +if __name__ == '__main__': + main() diff --git a/tests/unit/test_ws_connection.py b/tests/unit/test_ws_connection.py index 57c0f23d..6dfb73fa 100644 --- a/tests/unit/test_ws_connection.py +++ b/tests/unit/test_ws_connection.py @@ -2,6 +2,7 @@ import logging from unittest import IsolatedAsyncioTestCase +from unittest.mock import patch from surrealdb.connection_ws import WebsocketConnection @@ -12,6 +13,16 @@ async def asyncSetUp(self): self.ws_con = WebsocketConnection(base_url='ws://localhost:8000', logger=logger) await self.ws_con.connect() + async def test_one(self): + await self.ws_con.use("test", "test") + token = await self.ws_con.send('signin', {'user': 'root', 'pass': 'root'}) + await self.ws_con.unset("root") + print("Test set") + + + async def test_create_response_queue(self): + self.ws_con[1] = {} + async def test_send(self): await self.ws_con.use("test", "test") token = await self.ws_con.send('signin', {'user': 'root', 'pass': 'root'}) From 3abc0008f9097c94bffd6c147afd69ab8c7c70da Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Mon, 25 Nov 2024 08:20:34 +0000 Subject: [PATCH 2/5] adding mocks and finishing off the test --- surrealdb/connection.py | 38 ++++++++++++++++++++++++++++++--- tests/unit/test_connection.py | 40 +++++++++++++++++++++++++++++++---- 2 files changed, 71 insertions(+), 7 deletions(-) diff --git a/surrealdb/connection.py b/surrealdb/connection.py index b5a98abe..0041fe1d 100644 --- a/surrealdb/connection.py +++ b/surrealdb/connection.py @@ -1,3 +1,6 @@ +""" +Defines the base Connection class for sending and receiving requests. +""" import logging import secrets import string @@ -188,14 +191,34 @@ def get_response_queue(self, response_type: int, queue_id: str) -> Queue | None: if response_type_queues: return response_type_queues.get(queue_id) - def remove_response_queue(self, response_type: int, queue_id: str): + def remove_response_queue(self, response_type: int, queue_id: str) -> None: + """ + Remove a response queue for a given response type. + + Notes: + Does not alert if the key is missing + + Args: + response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR). + queue_id (str): The unique identifier for the queue. + """ lock = self._locks[response_type] with lock: response_type_queues = self._queues.get(response_type) if response_type_queues: response_type_queues.pop(queue_id, None) - async def send(self, method: str, *params): + async def send(self, method: str, *params) -> dict: + """ + Sends a request to the server with a unique ID and returns the response. + + Args: + method (str): The method of the request. + params: Parameters for the request. + + Returns: + dict: The response data from the request. + """ request_data = RequestData( id=request_id(REQUEST_ID_LENGTH), method=method, params=params ) @@ -219,7 +242,16 @@ async def send(self, method: str, *params): ) raise e - async def live_notifications(self, live_query_id: uuid.UUID): + async def live_notifications(self, live_query_id: uuid.UUID) -> Queue: + """ + Create a response queue for live notifications by essentially creating a NOTIFICATION response queue. + + Args: + live_query_id (uuid.UUID): The unique identifier for the live query. + + Returns: + Queue: The response queue for the live notifications. + """ queue = self.create_response_queue( ResponseType.NOTIFICATION, str(live_query_id) ) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 60cc1620..ceae2b2b 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -1,11 +1,13 @@ -import asyncio +""" +Defines the unit tests for the Connection class. +""" import logging - -from unittest import IsolatedAsyncioTestCase, main -from unittest.mock import patch import threading +from unittest import IsolatedAsyncioTestCase, main +from unittest.mock import patch, AsyncMock, MagicMock from surrealdb.connection import Connection, ResponseType, RequestData +from surrealdb.data.cbor import encode, decode class TestConnection(IsolatedAsyncioTestCase): @@ -86,6 +88,36 @@ async def test_create_response_queue(self): self.assertEqual(self.con._queues[1]["test_two"], outcome_three) self.assertEqual(id(outcome_three), id(self.con._queues[1]["test_two"])) + async def test_remove_response_queue(self): + self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test_two") + self.assertEqual(len(self.con._queues[1].keys()), 2) + + self.con.remove_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.assertEqual(len(self.con._queues[1].keys()), 1) + + self.con.remove_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.assertEqual(len(self.con._queues[1].keys()), 1) + + @patch("surrealdb.connection.request_id") + @patch("surrealdb.connection.Connection._make_request", new_callable=AsyncMock) + async def test_send(self, mock__make_request, mock_request_id): + mock_logger = MagicMock() + response_data = {"result": "test"} + self.con._logger = mock_logger + mock__make_request.return_value = response_data + mock_request_id.return_value = "1" + + request_data = RequestData(id="1", method="test", params=("test",)) + result = await self.con.send("test", "test") + + self.assertEqual(response_data, result) + mock__make_request.assert_called_once_with( + request_data, + encoder=encode, + decoder=decode + ) + self.assertEqual(3, mock_logger.debug.call_count) if __name__ == '__main__': From c64ff9a97b598a1c9b26be40059f225cb31ed904 Mon Sep 17 00:00:00 2001 From: Remade Date: Mon, 25 Nov 2024 18:43:33 +0100 Subject: [PATCH 3/5] mypy issue fix --- surrealdb/async_surrealdb.py | 2 +- surrealdb/connection.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/surrealdb/async_surrealdb.py b/surrealdb/async_surrealdb.py index 5c9dea1b..6c6d7c11 100644 --- a/surrealdb/async_surrealdb.py +++ b/surrealdb/async_surrealdb.py @@ -276,4 +276,4 @@ async def kill(self, live_query_id: uuid.UUID) -> None: :param live_query_id: The UUID of the live query to kill. """ - return await self.__connection.send("kill", live_query_id) + await self.__connection.send("kill", live_query_id) diff --git a/surrealdb/connection.py b/surrealdb/connection.py index 8abdaf79..d3b09030 100644 --- a/surrealdb/connection.py +++ b/surrealdb/connection.py @@ -178,7 +178,8 @@ def create_response_queue(self, response_type: int, queue_id: str) -> Queue: queue: Queue = Queue(maxsize=0) response_type_queues[queue_id] = queue self._queues[response_type] = response_type_queues - return queue + + return queue def get_response_queue(self, response_type: int, queue_id: str) -> Queue | None: """ @@ -195,8 +196,9 @@ def get_response_queue(self, response_type: int, queue_id: str) -> Queue | None: lock = self._locks[response_type] with lock: response_type_queues = self._queues.get(response_type) - if response_type_queues: - return response_type_queues.get(queue_id) + if not response_type_queues: + return None + return response_type_queues.get(queue_id) def remove_response_queue(self, response_type: int, queue_id: str) -> None: """ @@ -215,7 +217,7 @@ def remove_response_queue(self, response_type: int, queue_id: str) -> None: if response_type_queues: response_type_queues.pop(queue_id, None) - async def send(self, method: str, *params) -> dict: + async def send(self, method: str, *params): """ Sends a request to the server with a unique ID and returns the response. From bd450225650dcfbf509ed6a63b950e583b1d94d0 Mon Sep 17 00:00:00 2001 From: Remade Date: Fri, 29 Nov 2024 09:33:06 +0100 Subject: [PATCH 4/5] connection test fix --- tests/unit/test_connection.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index ceae2b2b..b7a3ee4c 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -15,7 +15,7 @@ class TestConnection(IsolatedAsyncioTestCase): async def asyncSetUp(self): self.logger = logging.getLogger(__name__) self.url: str = 'http://localhost:8000' - self.con = Connection(base_url=self.url, logger=self.logger) + self.con = Connection(base_url=self.url, logger=self.logger, encoder=encode, decoder=decode) async def test___init__(self): self.assertEqual(self.url, self.con._base_url) @@ -52,7 +52,7 @@ async def test_close(self): async def test__make_request(self): request_data = RequestData(id="1", method="test", params=()) with self.assertRaises(NotImplementedError) as context: - await self.con._make_request(request_data, encoder=lambda x: x, decoder=lambda x: x) + await self.con._make_request(request_data) message = str(context.exception) self.assertEqual("_make_request method must be implemented", message) @@ -112,11 +112,7 @@ async def test_send(self, mock__make_request, mock_request_id): result = await self.con.send("test", "test") self.assertEqual(response_data, result) - mock__make_request.assert_called_once_with( - request_data, - encoder=encode, - decoder=decode - ) + mock__make_request.assert_called_once_with(request_data) self.assertEqual(3, mock_logger.debug.call_count) From 39ab815e4ac512292bb3eea90abf76f77b9f98b7 Mon Sep 17 00:00:00 2001 From: Remade Date: Tue, 3 Dec 2024 06:36:36 +0100 Subject: [PATCH 5/5] PR fix --- surrealdb/connection_http.py | 4 +++- tests/unit/test_connection.py | 4 ++-- tests/unit/test_ws_connection.py | 3 --- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/surrealdb/connection_http.py b/surrealdb/connection_http.py index 9b3f1e1a..fd08b71d 100644 --- a/surrealdb/connection_http.py +++ b/surrealdb/connection_http.py @@ -51,7 +51,9 @@ def _prepare_query_method_params(self, params: Tuple) -> Tuple: async def _make_request(self, request_data: RequestData): if not self._is_ready: - raise SurrealDbConnectionError("connection not ready. Call the connect() method first") + raise SurrealDbConnectionError( + "connection not ready. Call the connect() method first" + ) if self._namespace is None: raise SurrealDbConnectionError("namespace not set") diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index a39635d8..b28a71ac 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -80,8 +80,8 @@ async def test_create_response_queue(self): # get a queue when there are queues in the dictionary with the same queue_id outcome_two = self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test") - self.assertNotEqual(outcome, outcome_two) - self.assertNotEqual(id(outcome), id(outcome_two)) + self.assertEqual(outcome, outcome_two) + self.assertEqual(id(outcome), id(outcome_two)) # get a queue when there are queues in the dictionary with different queue_id outcome_three = self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test_two") diff --git a/tests/unit/test_ws_connection.py b/tests/unit/test_ws_connection.py index 9a485c80..82fbaf07 100644 --- a/tests/unit/test_ws_connection.py +++ b/tests/unit/test_ws_connection.py @@ -24,9 +24,6 @@ async def test_one(self): await self.ws_con.unset("root") print("Test set") - async def test_create_response_queue(self): - self.ws_con[1] = {} - async def test_send(self): await self.ws_con.use("test", "test") token = await self.ws_con.send('signin', {'user': 'root', 'pass': 'root'})