Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding unittest for base connection object #120

Merged
merged 10 commits into from
Dec 3, 2024
1 change: 1 addition & 0 deletions surrealdb/async_surrealdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,5 @@ async def kill(self, live_query_id: uuid.UUID) -> None:

:param live_query_id: The UUID of the live query to kill.
"""

return await self.__connection.send(METHOD_KILL, live_query_id)
177 changes: 157 additions & 20 deletions surrealdb/connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""
Defines the base Connection class for sending and receiving requests.
"""

import logging
import secrets
import string
import logging
import threading
import uuid
from dataclasses import dataclass

from dataclasses import dataclass
from typing import Dict, Tuple
from surrealdb.constants import (
REQUEST_ID_LENGTH,
Expand All @@ -23,19 +27,50 @@


class ResponseType:
"""
Enum-like class representing response types for the connection.

Attributes:
SEND (int): Response type for standard requests.
NOTIFICATION (int): Response type for notifications.
ERROR (int): Response type for errors.
"""

SEND = 1
NOTIFICATION = 2
ERROR = 3


@dataclass
class RequestData:
"""
Represents the data for a request sent over the connection.

Attributes:
id (str): Unique identifier for the request.
method (str): The method name to invoke.
params (Tuple): Parameters for the method.
"""

id: str
method: str
params: Tuple


class Connection:
"""
Base class for managing a connection to the database.

Manages request/response lifecycle, including the use of queues for
handling asynchronous communication.

Attributes:
_queues (Dict[int, dict]): Mapping of response types to their queues.
_namespace (str | None): Current namespace in use.
_database (str | None): Current database in use.
_auth_token (str | None): Authentication token.
"""

_queues: Dict[int, dict]
_locks: Dict[int, threading.Lock]
_namespace: str | None = None
Expand All @@ -49,6 +84,15 @@ def __init__(
encoder,
decoder,
):
"""
Initialize the Connection instance.

Args:
base_url (str): The base URL of the server.
logger (logging.Logger): Logger for debugging and tracking activities.
encoder (function): Function to encode the request.
decoder (function): Function to decode the response.
"""
self._encoder = encoder
self._decoder = decoder

Expand All @@ -67,47 +111,120 @@ def __init__(
self._logger = logger

async def use(self, namespace: str, database: str) -> None:
pass
"""
Set the namespace and database for subsequent operations.

Args:
namespace (str): The namespace to use.
database (str): The database to use.
"""
raise NotImplementedError("use method must be implemented")

async def connect(self) -> None:
pass
"""
Establish a connection to the server.
"""
raise NotImplementedError("connect method must be implemented")

async def close(self) -> None:
pass
"""
Close the connection to the server.
"""
raise NotImplementedError("close method must be implemented")

async def _make_request(self, request_data: RequestData) -> dict:
"""
Internal method to send a request and handle the response.
Args:
request_data (RequestData): The data to send.
return:
dict: The response data from the request.
"""
raise NotImplementedError("_make_request method must be implemented")

async def _make_request(self, request_data: RequestData):
pass
async def set(self, key: str, value) -> None:
"""
Set a key-value pair in the database.

async def set(self, key: str, value):
pass
Args:
key (str): The key to set.
value: The value to set.
"""
raise NotImplementedError("set method must be implemented")

async def unset(self, key: str):
pass
async def unset(self, key: str) -> None:
"""
Unset a key-value pair in the database.

Args:
key (str): The key to unset.
"""
raise NotImplementedError("unset method must be implemented")

def set_token(self, token: str | None = None) -> None:
"""
Set the authentication token for the connection.

Args:
token (str): The authentication token to be set
"""
self._auth_token = token

def create_response_queue(self, response_type: int, queue_id: str):
def create_response_queue(self, response_type: int, queue_id: str) -> Queue:
"""
Create a response queue for a given response type.

Args:
response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR).
queue_id (str): The unique identifier for the queue.
Returns:
Queue: The response queue for the given response type and queue ID
(existing queues will be overwritten if same ID is used, cannot get existing queue).
"""
lock = self._locks[response_type]
with lock:
response_type_queues = self._queues.get(response_type)
if response_type_queues is None:
response_type_queues = {}

if response_type_queues.get(queue_id) is None:
queue: Queue = Queue(maxsize=0)
queue = response_type_queues.get(queue_id)
if queue is None:
queue = Queue(maxsize=0)
response_type_queues[queue_id] = queue
self._queues[response_type] = response_type_queues
return queue

def get_response_queue(self, response_type: int, queue_id: str):
return queue

def get_response_queue(self, response_type: int, queue_id: str) -> Queue | None:
"""
Get a response queue for a given response type.

Args:
response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR).
queue_id (str): The unique identifier for the queue.

Returns:
Queue: The response queue for the given response type and queue ID
(existing queues will be overwritten if same ID is used).
"""
lock = self._locks[response_type]
with lock:
response_type_queues = self._queues.get(response_type)
if response_type_queues:
return response_type_queues.get(queue_id)
if not response_type_queues:
return None
return response_type_queues.get(queue_id)

def remove_response_queue(self, response_type: int, queue_id: str):
def remove_response_queue(self, response_type: int, queue_id: str) -> None:
"""
Remove a response queue for a given response type.

Notes:
Does not alert if the key is missing

Args:
response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR).
queue_id (str): The unique identifier for the queue.
"""
lock = self._locks[response_type]
with lock:
response_type_queues = self._queues.get(response_type)
Expand All @@ -134,6 +251,17 @@ def _prepare_method_params(method: str, params) -> Tuple:
return prepared_params

async def send(self, method: str, *params):
"""
Sends a request to the server with a unique ID and returns the response.

Args:
method (str): The method of the request.
params: Parameters for the request.

Returns:
dict: The response data from the request.
"""

prepared_params = self._prepare_method_params(method, params)
request_data = RequestData(
id=request_id(REQUEST_ID_LENGTH), method=method, params=prepared_params
Expand All @@ -156,7 +284,16 @@ async def send(self, method: str, *params):
)
raise e

async def live_notifications(self, live_query_id: uuid.UUID):
async def live_notifications(self, live_query_id: uuid.UUID) -> Queue:
"""
Create a response queue for live notifications by essentially creating a NOTIFICATION response queue.

Args:
live_query_id (uuid.UUID): The unique identifier for the live query.

Returns:
Queue: The response queue for the live notifications.
"""
queue = self.create_response_queue(
ResponseType.NOTIFICATION, str(live_query_id)
)
Expand Down
2 changes: 1 addition & 1 deletion surrealdb/connection_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def connect(self):
self._lib.sr_free_string(c_err)

async def close(self):
pass
self._lib.sr_surreal_rpc_free(self._c_surreal_rpc)

async def use(self, namespace: str, database: str) -> None:
self._namespace = namespace
Expand Down
10 changes: 10 additions & 0 deletions surrealdb/connection_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class HTTPConnection(Connection):
_request_variables: dict[str, Any] = {}
_request_variables_lock = threading.Lock()
_is_ready: bool = False

async def use(self, namespace: str, database: str) -> None:
self._namespace = namespace
Expand All @@ -34,6 +35,10 @@ async def connect(self) -> None:
raise SurrealDbConnectionError(
"connection failed. check server is up and base url is correct"
)
self._is_ready = True

async def close(self):
self._is_ready = False

def _prepare_query_method_params(self, params: Tuple) -> Tuple:
query, variables = params
Expand All @@ -45,6 +50,11 @@ def _prepare_query_method_params(self, params: Tuple) -> Tuple:
return query, variables

async def _make_request(self, request_data: RequestData):
if not self._is_ready:
raise SurrealDbConnectionError(
"connection not ready. Call the connect() method first"
)

if self._namespace is None:
raise SurrealDbConnectionError("namespace not set")

Expand Down
15 changes: 14 additions & 1 deletion surrealdb/connection_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,23 @@ async def use(self, namespace: str, database: str) -> None:

await self.send(METHOD_USE, namespace, database)

async def set(self, key: str, value):
async def set(self, key: str, value) -> None:
"""
Set a key-value pair in the database.

Args:
key (str): The key to set.
value: The value to set.
"""
await self.send(METHOD_SET, key, value)

async def unset(self, key: str):
"""
Unset a key-value pair in the database.

Args:
key (str): The key to unset.
"""
await self.send(METHOD_UNSET, key)

async def close(self):
Expand Down
43 changes: 43 additions & 0 deletions surrealdb/data/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
## What is CBOR?
CBOR is a binary data serialization format similar to JSON but more compact, efficient, and capable of encoding a
broader range of data types. It is useful for exchanging structured data between systems, especially when performance
and size are critical.

## Purpose of the CBOR Implementation

The CBOR code here allows the custom SurrealDB types (e.g., `GeometryPoint`, `Table`, `Range`, etc.) to be serialized
into CBOR binary format and deserialized back into Python objects. This is necessary because these types are not natively
supported by CBOR; thus, custom encoding and decoding logic is implemented.

## Key Components

### Custom Types

`Range` Class: Represents a range with a beginning (`begin`) and end (`end`). These can either be included (`BoundIncluded`) or excluded (`BoundExcluded`).
`Table`, `RecordID`, `GeometryPoint`, etc.: Custom SurrealDB-specific data types, representing domain-specific constructs like tables, records, and geometrical objects.

### CBOR Encoder

The function `default_encoder` is used to encode custom Python objects into CBOR's binary format. This is done by associating a specific CBOR tag (a numeric identifier) with each data type.

For example:

`GeometryPoint` objects are encoded using the tag `TAG_GEOMETRY_POINT` with its coordinates as the value.
`Range` objects are encoded using the tag `TAG_BOUND_EXCLUDED` with a list [begin, end] as its value.
The `CBORTag` class is used to represent tagged data in `CBOR`.

### CBOR Decoder

The function `tag_decoder` is the inverse of `default_encoder`. It takes tagged CBOR data and reconstructs the corresponding Python objects.

For example:

When encountering the `TAG_GEOMETRY_POINT` tag, it creates a `GeometryPoint` object using the tag's value (coordinates).
When encountering the `TAG_RANGE` tag, it creates a `Range` object using the tag's value (begin and end).

### encode and decode Functions

These are high-level functions for serializing and deserializing data:

`encode(obj)`: Converts a Python object into CBOR binary format.
`decode(data)`: Converts CBOR binary data back into a Python object using the custom decoding logic.
Loading
Loading