diff --git a/surrealdb/async_surrealdb.py b/surrealdb/async_surrealdb.py index ecffb0e3..c24647ac 100644 --- a/surrealdb/async_surrealdb.py +++ b/surrealdb/async_surrealdb.py @@ -295,4 +295,5 @@ async def kill(self, live_query_id: uuid.UUID) -> None: :param live_query_id: The UUID of the live query to kill. """ + return await self.__connection.send(METHOD_KILL, live_query_id) diff --git a/surrealdb/connection.py b/surrealdb/connection.py index 8bd0d32a..94bb58ae 100644 --- a/surrealdb/connection.py +++ b/surrealdb/connection.py @@ -1,10 +1,14 @@ +""" +Defines the base Connection class for sending and receiving requests. +""" + +import logging import secrets import string -import logging import threading import uuid -from dataclasses import dataclass +from dataclasses import dataclass from typing import Dict, Tuple from surrealdb.constants import ( REQUEST_ID_LENGTH, @@ -23,6 +27,15 @@ class ResponseType: + """ + Enum-like class representing response types for the connection. + + Attributes: + SEND (int): Response type for standard requests. + NOTIFICATION (int): Response type for notifications. + ERROR (int): Response type for errors. + """ + SEND = 1 NOTIFICATION = 2 ERROR = 3 @@ -30,12 +43,34 @@ class ResponseType: @dataclass class RequestData: + """ + Represents the data for a request sent over the connection. + + Attributes: + id (str): Unique identifier for the request. + method (str): The method name to invoke. + params (Tuple): Parameters for the method. + """ + id: str method: str params: Tuple class Connection: + """ + Base class for managing a connection to the database. + + Manages request/response lifecycle, including the use of queues for + handling asynchronous communication. + + Attributes: + _queues (Dict[int, dict]): Mapping of response types to their queues. + _namespace (str | None): Current namespace in use. + _database (str | None): Current database in use. + _auth_token (str | None): Authentication token. + """ + _queues: Dict[int, dict] _locks: Dict[int, threading.Lock] _namespace: str | None = None @@ -49,6 +84,15 @@ def __init__( encoder, decoder, ): + """ + Initialize the Connection instance. + + Args: + base_url (str): The base URL of the server. + logger (logging.Logger): Logger for debugging and tracking activities. + encoder (function): Function to encode the request. + decoder (function): Function to decode the response. + """ self._encoder = encoder self._decoder = decoder @@ -67,47 +111,120 @@ def __init__( self._logger = logger async def use(self, namespace: str, database: str) -> None: - pass + """ + Set the namespace and database for subsequent operations. + + Args: + namespace (str): The namespace to use. + database (str): The database to use. + """ + raise NotImplementedError("use method must be implemented") async def connect(self) -> None: - pass + """ + Establish a connection to the server. + """ + raise NotImplementedError("connect method must be implemented") async def close(self) -> None: - pass + """ + Close the connection to the server. + """ + raise NotImplementedError("close method must be implemented") + + async def _make_request(self, request_data: RequestData) -> dict: + """ + Internal method to send a request and handle the response. + Args: + request_data (RequestData): The data to send. + return: + dict: The response data from the request. + """ + raise NotImplementedError("_make_request method must be implemented") - async def _make_request(self, request_data: RequestData): - pass + async def set(self, key: str, value) -> None: + """ + Set a key-value pair in the database. - async def set(self, key: str, value): - pass + Args: + key (str): The key to set. + value: The value to set. + """ + raise NotImplementedError("set method must be implemented") - async def unset(self, key: str): - pass + async def unset(self, key: str) -> None: + """ + Unset a key-value pair in the database. + + Args: + key (str): The key to unset. + """ + raise NotImplementedError("unset method must be implemented") def set_token(self, token: str | None = None) -> None: + """ + Set the authentication token for the connection. + + Args: + token (str): The authentication token to be set + """ self._auth_token = token - def create_response_queue(self, response_type: int, queue_id: str): + def create_response_queue(self, response_type: int, queue_id: str) -> Queue: + """ + Create a response queue for a given response type. + + Args: + response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR). + queue_id (str): The unique identifier for the queue. + Returns: + Queue: The response queue for the given response type and queue ID + (existing queues will be overwritten if same ID is used, cannot get existing queue). + """ lock = self._locks[response_type] with lock: response_type_queues = self._queues.get(response_type) if response_type_queues is None: response_type_queues = {} - if response_type_queues.get(queue_id) is None: - queue: Queue = Queue(maxsize=0) + queue = response_type_queues.get(queue_id) + if queue is None: + queue = Queue(maxsize=0) response_type_queues[queue_id] = queue self._queues[response_type] = response_type_queues - return queue - def get_response_queue(self, response_type: int, queue_id: str): + return queue + + def get_response_queue(self, response_type: int, queue_id: str) -> Queue | None: + """ + Get a response queue for a given response type. + + Args: + response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR). + queue_id (str): The unique identifier for the queue. + + Returns: + Queue: The response queue for the given response type and queue ID + (existing queues will be overwritten if same ID is used). + """ lock = self._locks[response_type] with lock: response_type_queues = self._queues.get(response_type) - if response_type_queues: - return response_type_queues.get(queue_id) + if not response_type_queues: + return None + return response_type_queues.get(queue_id) - def remove_response_queue(self, response_type: int, queue_id: str): + def remove_response_queue(self, response_type: int, queue_id: str) -> None: + """ + Remove a response queue for a given response type. + + Notes: + Does not alert if the key is missing + + Args: + response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR). + queue_id (str): The unique identifier for the queue. + """ lock = self._locks[response_type] with lock: response_type_queues = self._queues.get(response_type) @@ -134,6 +251,17 @@ def _prepare_method_params(method: str, params) -> Tuple: return prepared_params async def send(self, method: str, *params): + """ + Sends a request to the server with a unique ID and returns the response. + + Args: + method (str): The method of the request. + params: Parameters for the request. + + Returns: + dict: The response data from the request. + """ + prepared_params = self._prepare_method_params(method, params) request_data = RequestData( id=request_id(REQUEST_ID_LENGTH), method=method, params=prepared_params @@ -156,7 +284,16 @@ async def send(self, method: str, *params): ) raise e - async def live_notifications(self, live_query_id: uuid.UUID): + async def live_notifications(self, live_query_id: uuid.UUID) -> Queue: + """ + Create a response queue for live notifications by essentially creating a NOTIFICATION response queue. + + Args: + live_query_id (uuid.UUID): The unique identifier for the live query. + + Returns: + Queue: The response queue for the live notifications. + """ queue = self.create_response_queue( ResponseType.NOTIFICATION, str(live_query_id) ) diff --git a/surrealdb/connection_clib.py b/surrealdb/connection_clib.py index 35683bc6..cdb11a5d 100644 --- a/surrealdb/connection_clib.py +++ b/surrealdb/connection_clib.py @@ -180,7 +180,7 @@ async def connect(self): self._lib.sr_free_string(c_err) async def close(self): - pass + self._lib.sr_surreal_rpc_free(self._c_surreal_rpc) async def use(self, namespace: str, database: str) -> None: self._namespace = namespace diff --git a/surrealdb/connection_http.py b/surrealdb/connection_http.py index cfa70345..fd08b71d 100644 --- a/surrealdb/connection_http.py +++ b/surrealdb/connection_http.py @@ -10,6 +10,7 @@ class HTTPConnection(Connection): _request_variables: dict[str, Any] = {} _request_variables_lock = threading.Lock() + _is_ready: bool = False async def use(self, namespace: str, database: str) -> None: self._namespace = namespace @@ -34,6 +35,10 @@ async def connect(self) -> None: raise SurrealDbConnectionError( "connection failed. check server is up and base url is correct" ) + self._is_ready = True + + async def close(self): + self._is_ready = False def _prepare_query_method_params(self, params: Tuple) -> Tuple: query, variables = params @@ -45,6 +50,11 @@ def _prepare_query_method_params(self, params: Tuple) -> Tuple: return query, variables async def _make_request(self, request_data: RequestData): + if not self._is_ready: + raise SurrealDbConnectionError( + "connection not ready. Call the connect() method first" + ) + if self._namespace is None: raise SurrealDbConnectionError("namespace not set") diff --git a/surrealdb/connection_ws.py b/surrealdb/connection_ws.py index 07beb325..9abf1f01 100644 --- a/surrealdb/connection_ws.py +++ b/surrealdb/connection_ws.py @@ -28,10 +28,23 @@ async def use(self, namespace: str, database: str) -> None: await self.send(METHOD_USE, namespace, database) - async def set(self, key: str, value): + async def set(self, key: str, value) -> None: + """ + Set a key-value pair in the database. + + Args: + key (str): The key to set. + value: The value to set. + """ await self.send(METHOD_SET, key, value) async def unset(self, key: str): + """ + Unset a key-value pair in the database. + + Args: + key (str): The key to unset. + """ await self.send(METHOD_UNSET, key) async def close(self): diff --git a/surrealdb/data/README.md b/surrealdb/data/README.md new file mode 100644 index 00000000..e6451294 --- /dev/null +++ b/surrealdb/data/README.md @@ -0,0 +1,43 @@ +## What is CBOR? +CBOR is a binary data serialization format similar to JSON but more compact, efficient, and capable of encoding a +broader range of data types. It is useful for exchanging structured data between systems, especially when performance +and size are critical. + +## Purpose of the CBOR Implementation + +The CBOR code here allows the custom SurrealDB types (e.g., `GeometryPoint`, `Table`, `Range`, etc.) to be serialized +into CBOR binary format and deserialized back into Python objects. This is necessary because these types are not natively +supported by CBOR; thus, custom encoding and decoding logic is implemented. + +## Key Components + +### Custom Types + +`Range` Class: Represents a range with a beginning (`begin`) and end (`end`). These can either be included (`BoundIncluded`) or excluded (`BoundExcluded`). +`Table`, `RecordID`, `GeometryPoint`, etc.: Custom SurrealDB-specific data types, representing domain-specific constructs like tables, records, and geometrical objects. + +### CBOR Encoder + +The function `default_encoder` is used to encode custom Python objects into CBOR's binary format. This is done by associating a specific CBOR tag (a numeric identifier) with each data type. + +For example: + +`GeometryPoint` objects are encoded using the tag `TAG_GEOMETRY_POINT` with its coordinates as the value. +`Range` objects are encoded using the tag `TAG_BOUND_EXCLUDED` with a list [begin, end] as its value. +The `CBORTag` class is used to represent tagged data in `CBOR`. + +### CBOR Decoder + +The function `tag_decoder` is the inverse of `default_encoder`. It takes tagged CBOR data and reconstructs the corresponding Python objects. + +For example: + +When encountering the `TAG_GEOMETRY_POINT` tag, it creates a `GeometryPoint` object using the tag's value (coordinates). +When encountering the `TAG_RANGE` tag, it creates a `Range` object using the tag's value (begin and end). + +### encode and decode Functions + +These are high-level functions for serializing and deserializing data: + +`encode(obj)`: Converts a Python object into CBOR binary format. +`decode(data)`: Converts CBOR binary data back into a Python object using the custom decoding logic. \ No newline at end of file diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py new file mode 100644 index 00000000..b28a71ac --- /dev/null +++ b/tests/unit/test_connection.py @@ -0,0 +1,120 @@ +""" +Defines the unit tests for the Connection class. +""" +import logging +import threading +from unittest import IsolatedAsyncioTestCase, main +from unittest.mock import patch, AsyncMock, MagicMock + +from surrealdb.connection import Connection, ResponseType, RequestData +from surrealdb.data.cbor import encode, decode + + +class TestConnection(IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.logger = logging.getLogger(__name__) + self.url: str = 'http://localhost:8000' + self.con = Connection(base_url=self.url, logger=self.logger, encoder=encode, decoder=decode) + + async def test___init__(self): + self.assertEqual(self.url, self.con._base_url) + self.assertEqual(self.logger, self.con._logger) + + # assert that the locks are of type threading.Lock + self.assertEqual(type(threading.Lock()), self.con._locks[ResponseType.SEND].__class__) + self.assertEqual(type(threading.Lock()), self.con._locks[ResponseType.NOTIFICATION].__class__) + self.assertEqual(type(threading.Lock()), self.con._locks[ResponseType.ERROR].__class__) + + # assert that the queues are of type dict + self.assertEqual(dict(), self.con._queues[ResponseType.SEND]) + self.assertEqual(dict(), self.con._queues[ResponseType.NOTIFICATION]) + self.assertEqual(dict(), self.con._queues[ResponseType.ERROR]) + + async def test_use(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.use("test", "test") + message = str(context.exception) + self.assertEqual("use method must be implemented", message) + + async def test_connect(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.connect() + message = str(context.exception) + self.assertEqual("connect method must be implemented", message) + + async def test_close(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.close() + message = str(context.exception) + self.assertEqual("close method must be implemented", message) + + async def test__make_request(self): + request_data = RequestData(id="1", method="test", params=()) + with self.assertRaises(NotImplementedError) as context: + await self.con._make_request(request_data) + message = str(context.exception) + self.assertEqual("_make_request method must be implemented", message) + + async def test_set(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.set("test", "test") + message = str(context.exception) + self.assertEqual("set method must be implemented", message) + + async def test_unset(self): + with self.assertRaises(NotImplementedError) as context: + await self.con.unset("test") + message = str(context.exception) + self.assertEqual("unset method must be implemented", message) + + async def test_set_token(self): + self.con.set_token("test") + self.assertEqual("test", self.con._auth_token) + + async def test_create_response_queue(self): + # get a queue when there are now queues in the dictionary + outcome = self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.assertEqual(self.con._queues[ResponseType.SEND]["test"], outcome) + self.assertEqual(id(outcome), id(self.con._queues[ResponseType.SEND]["test"])) + + # get a queue when there are queues in the dictionary with the same queue_id + outcome_two = self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.assertEqual(outcome, outcome_two) + self.assertEqual(id(outcome), id(outcome_two)) + + # get a queue when there are queues in the dictionary with different queue_id + outcome_three = self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test_two") + self.assertEqual(self.con._queues[1]["test_two"], outcome_three) + self.assertEqual(id(outcome_three), id(self.con._queues[1]["test_two"])) + + async def test_remove_response_queue(self): + self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.con.create_response_queue(response_type=ResponseType.SEND, queue_id="test_two") + self.assertEqual(len(self.con._queues[1].keys()), 2) + + self.con.remove_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.assertEqual(len(self.con._queues[1].keys()), 1) + + self.con.remove_response_queue(response_type=ResponseType.SEND, queue_id="test") + self.assertEqual(len(self.con._queues[1].keys()), 1) + + @patch("surrealdb.connection.request_id") + @patch("surrealdb.connection.Connection._make_request", new_callable=AsyncMock) + async def test_send(self, mock__make_request, mock_request_id): + mock_logger = MagicMock() + response_data = {"result": "test"} + self.con._logger = mock_logger + mock__make_request.return_value = response_data + mock_request_id.return_value = "1" + + request_data = RequestData(id="1", method="test", params=("test",)) + result = await self.con.send("test", "test") + + self.assertEqual(response_data, result) + mock__make_request.assert_called_once_with(request_data) + self.assertEqual(3, mock_logger.debug.call_count) + + +if __name__ == '__main__': + main() diff --git a/tests/unit/test_ws_connection.py b/tests/unit/test_ws_connection.py index 35614c22..82fbaf07 100644 --- a/tests/unit/test_ws_connection.py +++ b/tests/unit/test_ws_connection.py @@ -2,6 +2,7 @@ import logging from unittest import IsolatedAsyncioTestCase +from unittest.mock import patch from surrealdb.connection_ws import WebsocketConnection from surrealdb.data.cbor import encode, decode @@ -17,6 +18,12 @@ async def asyncTearDown(self): await self.ws_con.send("query", "DELETE users;") await self.ws_con.close() + async def test_one(self): + await self.ws_con.use("test", "test") + token = await self.ws_con.send('signin', {'user': 'root', 'pass': 'root'}) + await self.ws_con.unset("root") + print("Test set") + async def test_send(self): await self.ws_con.use("test", "test") token = await self.ws_con.send('signin', {'user': 'root', 'pass': 'root'})