Skip to content

Commit

Permalink
1003 Support arrays of timestamp / timestamptz / date / time in SQLite (
Browse files Browse the repository at this point in the history
#1004)

* support date / time / timestamp arrays in SQLite

* don't run tests for cockroachdb for now
  • Loading branch information
dantownsend authored May 31, 2024
1 parent b1de70a commit 9f1a4d8
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 64 deletions.
31 changes: 26 additions & 5 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2532,7 +2532,14 @@ def column_type(self):
if engine_type in ("postgres", "cockroach"):
return f"{self.base_column.column_type}[]"
elif engine_type == "sqlite":
return "ARRAY"
inner_column = self._get_inner_column()
return (
f"ARRAY_{inner_column.column_type}"
if isinstance(
inner_column, (Date, Timestamp, Timestamptz, Time)
)
else "ARRAY"
)
raise Exception("Unrecognized engine type")

def _setup_base_column(self, table_class: t.Type[Table]):
Expand Down Expand Up @@ -2564,6 +2571,23 @@ def _get_dimensions(self, start: int = 0) -> int:
else:
return start + 1

def _get_inner_column(self) -> Column:
    """
    Return the innermost ``Column`` wrapped by this array, stepping through
    any nested ``Array`` layers on the way down. For example::

        >>> Array(Varchar())._get_inner_column()
        Varchar

        >>> Array(Array(Varchar()))._get_inner_column()
        Varchar

    """
    column = self.base_column
    # Keep descending until the base column is no longer an array itself.
    while isinstance(column, Array):
        column = column.base_column
    return column

def _get_inner_value_type(self) -> t.Type:
"""
A helper function to get the innermost value type for the array. For
Expand All @@ -2576,10 +2600,7 @@ def _get_inner_value_type(self) -> t.Type:
str
"""
if isinstance(self.base_column, Array):
return self.base_column._get_inner_value_type()
else:
return self.base_column.value_type
return self._get_inner_column().value_type

def __getitem__(self, value: int) -> Array:
"""
Expand Down
228 changes: 171 additions & 57 deletions piccolo/engine/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import uuid
from dataclasses import dataclass
from decimal import Decimal
from functools import partial, wraps

from piccolo.engine.base import Batch, Engine, validate_savepoint_name
from piccolo.engine.exceptions import TransactionError
Expand All @@ -35,14 +36,14 @@
# In


def convert_numeric_in(value):
def convert_numeric_in(value: Decimal) -> float:
"""
Convert any Decimal values into floats.
"""
return float(value)


def convert_uuid_in(value) -> str:
def convert_uuid_in(value: uuid.UUID) -> str:
"""
Converts the UUID value being passed into sqlite.
"""
Expand All @@ -56,7 +57,7 @@ def convert_time_in(value: datetime.time) -> str:
return value.isoformat()


def convert_date_in(value: datetime.date):
def convert_date_in(value: datetime.date) -> str:
"""
Converts the date value being passed into sqlite.
"""
Expand All @@ -74,122 +75,235 @@ def convert_datetime_in(value: datetime.datetime) -> str:
return str(value)


def convert_timedelta_in(value: datetime.timedelta):
def convert_timedelta_in(value: datetime.timedelta) -> float:
"""
Converts the timedelta value being passed into sqlite.
"""
return value.total_seconds()


def convert_array_in(value: list):
def convert_array_in(value: list) -> str:
"""
Converts a list value into a string.
Converts a list value into a string (it handles nested lists, and types like
datetime / time / date which aren't usually JSON serialisable).
"""
if value and type(value[0]) not in [str, int, float, list]:
raise ValueError("Can only serialise str, int, float, and list.")

return dump_json(value)
def serialise(data: list):
output = []

for item in data:
if isinstance(item, list):
output.append(serialise(item))
elif isinstance(
item, (datetime.datetime, datetime.time, datetime.date)
):
if adapter := ADAPTERS.get(type(item)):
output.append(adapter(item))
else:
raise ValueError("The adapter wasn't found.")
elif item is None or isinstance(item, (str, int, float, list)):
# We can safely JSON serialise these.
output.append(item)
else:
raise ValueError("We can't currently serialise this value.")

return output

return dump_json(serialise(value))


###############################################################################

# Register adapters

# Maps each Python type to the function which converts a value of that type
# before it is passed into SQLite. ``convert_array_in`` also looks entries up
# here by exact type (``ADAPTERS.get(type(item))``) when serialising the
# date / time / datetime elements of a list.
ADAPTERS: t.Dict[t.Type, t.Callable[[t.Any], t.Any]] = {
Decimal: convert_numeric_in,
uuid.UUID: convert_uuid_in,
datetime.time: convert_time_in,
datetime.date: convert_date_in,
datetime.datetime: convert_datetime_in,
datetime.timedelta: convert_timedelta_in,
list: convert_array_in,
}

# Hand every (type, adapter) pair to the sqlite3 module.
for value_type, adapter in ADAPTERS.items():
sqlite3.register_adapter(value_type, adapter)

###############################################################################

# Out


def convert_numeric_out(value: bytes) -> Decimal:
def decode_to_string(converter: t.Callable[[str], t.Any]):
    """
    Wrap ``converter`` so it can be called with either ``str`` or ``bytes``.

    Converters receive bytes when SQLite calls them directly, and strings
    when they are used by the array converters. ``bytes`` input is decoded as
    UTF-8 before being handed to ``converter``; ``str`` input is passed
    through unchanged. Any other input type raises a ``ValueError``.
    """

    @wraps(converter)
    def wrapper(value: t.Union[str, bytes]) -> t.Any:
        if isinstance(value, str):
            return converter(value)
        if isinstance(value, bytes):
            return converter(value.decode("utf8"))
        raise ValueError("Unsupported type")

    return wrapper


@decode_to_string
def convert_numeric_out(value: str) -> Decimal:
"""
Convert float values into Decimals.
"""
return Decimal(value.decode("ascii"))
return Decimal(value)


def convert_int_out(value: bytes) -> int:
@decode_to_string
def convert_int_out(value: str) -> int:
"""
Make sure Integer values are actually of type int.
"""
return int(float(value))


def convert_uuid_out(value: bytes) -> uuid.UUID:
@decode_to_string
def convert_uuid_out(value: str) -> uuid.UUID:
"""
If the value is a uuid, convert it to a UUID instance.
"""
return uuid.UUID(value.decode("utf8"))
return uuid.UUID(value)


def convert_date_out(value: bytes) -> datetime.date:
return datetime.date.fromisoformat(value.decode("utf8"))
@decode_to_string
def convert_date_out(value: str) -> datetime.date:
return datetime.date.fromisoformat(value)


def convert_time_out(value: bytes) -> datetime.time:
@decode_to_string
def convert_time_out(value: str) -> datetime.time:
"""
If the value is a time, convert it to a time instance.
"""
return datetime.time.fromisoformat(value.decode("utf8"))
return datetime.time.fromisoformat(value)


def convert_seconds_out(value: bytes) -> datetime.timedelta:
@decode_to_string
def convert_seconds_out(value: str) -> datetime.timedelta:
"""
If the value is from a seconds column, convert it to a timedelta instance.
"""
return datetime.timedelta(seconds=float(value.decode("utf8")))
return datetime.timedelta(seconds=float(value))


def convert_boolean_out(value: bytes) -> bool:
@decode_to_string
def convert_boolean_out(value: str) -> bool:
"""
If the value is from a boolean column, convert it to a bool value.
"""
_value = value.decode("utf8")
return _value == "1"
return value == "1"


def convert_timestamp_out(value: bytes) -> datetime.datetime:
@decode_to_string
def convert_timestamp_out(value: str) -> datetime.datetime:
"""
If the value is from a timestamp column, convert it to a datetime value.
"""
return datetime.datetime.fromisoformat(value.decode("utf8"))
return datetime.datetime.fromisoformat(value)


def convert_timestamptz_out(value: bytes) -> datetime.datetime:
@decode_to_string
def convert_timestamptz_out(value: str) -> datetime.datetime:
"""
If the value is from a timestamptz column, convert it to a datetime value,
with a timezone of UTC.
"""
_value = datetime.datetime.fromisoformat(value.decode("utf8"))
_value = _value.replace(tzinfo=datetime.timezone.utc)
return _value
return datetime.datetime.fromisoformat(value).replace(
tzinfo=datetime.timezone.utc
)


def convert_array_out(value: bytes) -> t.List:
@decode_to_string
def convert_array_out(value: str) -> t.List:
"""
If the value is from an array column, deserialise the string back into a
list.
"""
return load_json(value.decode("utf8"))


def convert_M2M_out(value: bytes) -> t.List:
_value = value.decode("utf8")
return _value.split(",")


sqlite3.register_converter("Numeric", convert_numeric_out)
sqlite3.register_converter("Integer", convert_int_out)
sqlite3.register_converter("UUID", convert_uuid_out)
sqlite3.register_converter("Date", convert_date_out)
sqlite3.register_converter("Time", convert_time_out)
sqlite3.register_converter("Seconds", convert_seconds_out)
sqlite3.register_converter("Boolean", convert_boolean_out)
sqlite3.register_converter("Timestamp", convert_timestamp_out)
sqlite3.register_converter("Timestamptz", convert_timestamptz_out)
sqlite3.register_converter("Array", convert_array_out)
sqlite3.register_converter("M2M", convert_M2M_out)

sqlite3.register_adapter(Decimal, convert_numeric_in)
sqlite3.register_adapter(uuid.UUID, convert_uuid_in)
sqlite3.register_adapter(datetime.time, convert_time_in)
sqlite3.register_adapter(datetime.date, convert_date_in)
sqlite3.register_adapter(datetime.datetime, convert_datetime_in)
sqlite3.register_adapter(datetime.timedelta, convert_timedelta_in)
sqlite3.register_adapter(list, convert_array_in)
return load_json(value)


def convert_complex_array_out(
    value: t.Union[str, bytes], converter: t.Callable
) -> t.Any:
    """
    This is used to handle arrays of things like timestamps, which we can't
    just load from JSON without doing additional work to convert the elements
    back into Python objects.

    :param value:
        The JSON encoded array. SQLite passes ``bytes``, which are decoded as
        UTF-8; a ``str`` is also accepted and used as-is.
    :param converter:
        Called with each string element of the (possibly nested) array, and
        its return value replaces that element.
    :returns:
        The parsed list, with every string element converted. Nested lists
        are converted recursively, and non-string elements (e.g. ``None``)
        are left untouched. If the parsed JSON isn't a list, it is returned
        unchanged.

    """
    # Accept both input types, like the converters wrapped in
    # ``decode_to_string`` do.
    text = value.decode("utf8") if isinstance(value, bytes) else value
    parsed = load_json(text)

    def convert_item(item: t.Any) -> t.Any:
        if isinstance(item, list):
            # For nested arrays
            return [convert_item(inner) for inner in item]
        if isinstance(item, str):
            return converter(item)
        return item

    if isinstance(parsed, list):
        return [convert_item(item) for item in parsed]
    return parsed


@decode_to_string
def convert_M2M_out(value: str) -> t.List:
    """
    Split the comma separated string from an M2M column into a list of
    strings.
    """
    parts = value.split(",")
    return parts


###############################################################################
# Register the basic converters

# Maps a column type name to the function which converts the raw value coming
# out of SQLite back into a Python object. The array registration further down
# reuses some of these entries as per-element converters.
CONVERTERS = {
"NUMERIC": convert_numeric_out,
"INTEGER": convert_int_out,
"UUID": convert_uuid_out,
"DATE": convert_date_out,
"TIME": convert_time_out,
"SECONDS": convert_seconds_out,
"BOOLEAN": convert_boolean_out,
"TIMESTAMP": convert_timestamp_out,
"TIMESTAMPTZ": convert_timestamptz_out,
"M2M": convert_M2M_out,
}

# Hand every (column type name, converter) pair to the sqlite3 module.
for column_name, converter in CONVERTERS.items():
sqlite3.register_converter(column_name, converter)

###############################################################################
# Register the array converters

# The ARRAY column type handles values which can be easily serialised to and
# from JSON.
sqlite3.register_converter("ARRAY", convert_array_out)

# We have special column types for arrays of timestamps etc, as simply loading
# the JSON isn't sufficient.
# Each one is registered as ``ARRAY_<column type>``. ``partial`` binds the
# matching element converter from ``CONVERTERS`` up front, so the registered
# callable only needs the raw value.
for column_name in ("TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME"):
sqlite3.register_converter(
f"ARRAY_{column_name}",
partial(
convert_complex_array_out,
converter=CONVERTERS[column_name],
),
)

###############################################################################

Expand Down
Loading

0 comments on commit 9f1a4d8

Please sign in to comment.