From 9f1a4d8d1e61d2d6b1b100fe0cd64ea44261ed25 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 31 May 2024 12:32:11 +0100 Subject: [PATCH] 1003 Support arrays of timestamp / timestamptz / date / time in SQLite (#1004) * support date / time / timestamp arrays in SQLite * don't run tests for cockroachdb for now --- piccolo/columns/column_types.py | 31 ++++- piccolo/engine/sqlite.py | 228 ++++++++++++++++++++++++-------- tests/columns/test_array.py | 91 ++++++++++++- 3 files changed, 286 insertions(+), 64 deletions(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index add0c6f5c..e16ffe6ab 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -2532,7 +2532,14 @@ def column_type(self): if engine_type in ("postgres", "cockroach"): return f"{self.base_column.column_type}[]" elif engine_type == "sqlite": - return "ARRAY" + inner_column = self._get_inner_column() + return ( + f"ARRAY_{inner_column.column_type}" + if isinstance( + inner_column, (Date, Timestamp, Timestamptz, Time) + ) + else "ARRAY" + ) raise Exception("Unrecognized engine type") def _setup_base_column(self, table_class: t.Type[Table]): @@ -2564,6 +2571,23 @@ def _get_dimensions(self, start: int = 0) -> int: else: return start + 1 + def _get_inner_column(self) -> Column: + """ + A helper function to get the innermost ``Column`` for the array. For + example:: + + >>> Array(Varchar())._get_inner_column() + Varchar + + >>> Array(Array(Varchar()))._get_inner_column() + Varchar + + """ + if isinstance(self.base_column, Array): + return self.base_column._get_inner_column() + else: + return self.base_column + def _get_inner_value_type(self) -> t.Type: """ A helper function to get the innermost value type for the array. For @@ -2576,10 +2600,7 @@ def _get_inner_value_type(self) -> t.Type: str """ - if isinstance(self.base_column, Array): - return self.base_column._get_inner_value_type() - else: - return self.base_column.value_type + return self._get_inner_column().value_type def __getitem__(self, value: int) -> Array: """ diff --git a/piccolo/engine/sqlite.py b/piccolo/engine/sqlite.py index 862afaa71..f6fbd4e38 100644 --- a/piccolo/engine/sqlite.py +++ b/piccolo/engine/sqlite.py @@ -9,6 +9,7 @@ import uuid from dataclasses import dataclass from decimal import Decimal +from functools import partial, wraps from piccolo.engine.base import Batch, Engine, validate_savepoint_name from piccolo.engine.exceptions import TransactionError @@ -35,14 +36,14 @@ # In -def convert_numeric_in(value): +def convert_numeric_in(value: Decimal) -> float: """ Convert any Decimal values into floats. """ return float(value) -def convert_uuid_in(value) -> str: +def convert_uuid_in(value: uuid.UUID) -> str: """ Converts the UUID value being passed into sqlite. """ @@ -56,7 +57,7 @@ def convert_time_in(value: datetime.time) -> str: return value.isoformat() -def convert_date_in(value: datetime.date): +def convert_date_in(value: datetime.date) -> str: """ Converts the date value being passed into sqlite. """ @@ -74,122 +75,235 @@ def convert_datetime_in(value: datetime.datetime) -> str: return str(value) -def convert_timedelta_in(value: datetime.timedelta): +def convert_timedelta_in(value: datetime.timedelta) -> float: """ Converts the timedelta value being passed into sqlite. """ return value.total_seconds() -def convert_array_in(value: list): +def convert_array_in(value: list) -> str: """ - Converts a list value into a string. + Converts a list value into a string (it handles nested lists, and type like + dateime/ time / date which aren't usually JSON serialisable.). + """ - if value and type(value[0]) not in [str, int, float, list]: - raise ValueError("Can only serialise str, int, float, and list.") - return dump_json(value) + def serialise(data: list): + output = [] + + for item in data: + if isinstance(item, list): + output.append(serialise(item)) + elif isinstance( + item, (datetime.datetime, datetime.time, datetime.date) + ): + if adapter := ADAPTERS.get(type(item)): + output.append(adapter(item)) + else: + raise ValueError("The adapter wasn't found.") + elif item is None or isinstance(item, (str, int, float, list)): + # We can safely JSON serialise these. + output.append(item) + else: + raise ValueError("We can't currently serialise this value.") + + return output + + return dump_json(serialise(value)) + + +############################################################################### + +# Register adapters + +ADAPTERS: t.Dict[t.Type, t.Callable[[t.Any], t.Any]] = { + Decimal: convert_numeric_in, + uuid.UUID: convert_uuid_in, + datetime.time: convert_time_in, + datetime.date: convert_date_in, + datetime.datetime: convert_datetime_in, + datetime.timedelta: convert_timedelta_in, + list: convert_array_in, +} +for value_type, adapter in ADAPTERS.items(): + sqlite3.register_adapter(value_type, adapter) + +############################################################################### # Out -def convert_numeric_out(value: bytes) -> Decimal: +def decode_to_string(converter: t.Callable[[str], t.Any]): + """ + This means we can use our converters with string and bytes. They are + passed bytes when used directly via SQLite, and are passed strings when + used by the array converters. + """ + + @wraps(converter) + def wrapper(value: t.Union[str, bytes]) -> t.Any: + if isinstance(value, bytes): + return converter(value.decode("utf8")) + elif isinstance(value, str): + return converter(value) + else: + raise ValueError("Unsupported type") + + return wrapper + + +@decode_to_string +def convert_numeric_out(value: str) -> Decimal: """ Convert float values into Decimals. """ - return Decimal(value.decode("ascii")) + return Decimal(value) -def convert_int_out(value: bytes) -> int: +@decode_to_string +def convert_int_out(value: str) -> int: """ Make sure Integer values are actually of type int. """ return int(float(value)) -def convert_uuid_out(value: bytes) -> uuid.UUID: +@decode_to_string +def convert_uuid_out(value: str) -> uuid.UUID: """ If the value is a uuid, convert it to a UUID instance. """ - return uuid.UUID(value.decode("utf8")) + return uuid.UUID(value) -def convert_date_out(value: bytes) -> datetime.date: - return datetime.date.fromisoformat(value.decode("utf8")) +@decode_to_string +def convert_date_out(value: str) -> datetime.date: + return datetime.date.fromisoformat(value) -def convert_time_out(value: bytes) -> datetime.time: +@decode_to_string +def convert_time_out(value: str) -> datetime.time: """ If the value is a time, convert it to a UUID instance. """ - return datetime.time.fromisoformat(value.decode("utf8")) + return datetime.time.fromisoformat(value) -def convert_seconds_out(value: bytes) -> datetime.timedelta: +@decode_to_string +def convert_seconds_out(value: str) -> datetime.timedelta: """ If the value is from a seconds column, convert it to a timedelta instance. """ - return datetime.timedelta(seconds=float(value.decode("utf8"))) + return datetime.timedelta(seconds=float(value)) -def convert_boolean_out(value: bytes) -> bool: +@decode_to_string +def convert_boolean_out(value: str) -> bool: """ If the value is from a boolean column, convert it to a bool value. """ - _value = value.decode("utf8") - return _value == "1" + return value == "1" -def convert_timestamp_out(value: bytes) -> datetime.datetime: +@decode_to_string +def convert_timestamp_out(value: str) -> datetime.datetime: """ If the value is from a timestamp column, convert it to a datetime value. """ - return datetime.datetime.fromisoformat(value.decode("utf8")) + return datetime.datetime.fromisoformat(value) -def convert_timestamptz_out(value: bytes) -> datetime.datetime: +@decode_to_string +def convert_timestamptz_out(value: str) -> datetime.datetime: """ If the value is from a timestamptz column, convert it to a datetime value, with a timezone of UTC. """ - _value = datetime.datetime.fromisoformat(value.decode("utf8")) - _value = _value.replace(tzinfo=datetime.timezone.utc) - return _value + return datetime.datetime.fromisoformat(value).replace( + tzinfo=datetime.timezone.utc + ) -def convert_array_out(value: bytes) -> t.List: +@decode_to_string +def convert_array_out(value: str) -> t.List: """ If the value if from an array column, deserialise the string back into a list. """ - return load_json(value.decode("utf8")) - - -def convert_M2M_out(value: bytes) -> t.List: - _value = value.decode("utf8") - return _value.split(",") - - -sqlite3.register_converter("Numeric", convert_numeric_out) -sqlite3.register_converter("Integer", convert_int_out) -sqlite3.register_converter("UUID", convert_uuid_out) -sqlite3.register_converter("Date", convert_date_out) -sqlite3.register_converter("Time", convert_time_out) -sqlite3.register_converter("Seconds", convert_seconds_out) -sqlite3.register_converter("Boolean", convert_boolean_out) -sqlite3.register_converter("Timestamp", convert_timestamp_out) -sqlite3.register_converter("Timestamptz", convert_timestamptz_out) -sqlite3.register_converter("Array", convert_array_out) -sqlite3.register_converter("M2M", convert_M2M_out) - -sqlite3.register_adapter(Decimal, convert_numeric_in) -sqlite3.register_adapter(uuid.UUID, convert_uuid_in) -sqlite3.register_adapter(datetime.time, convert_time_in) -sqlite3.register_adapter(datetime.date, convert_date_in) -sqlite3.register_adapter(datetime.datetime, convert_datetime_in) -sqlite3.register_adapter(datetime.timedelta, convert_timedelta_in) -sqlite3.register_adapter(list, convert_array_in) + return load_json(value) + + +def convert_complex_array_out(value: bytes, converter: t.Callable): + """ + This is used to handle arrays of things like timestamps, which we can't + just load from JSON without doing additional work to convert the elements + back into Python objects. + """ + parsed = load_json(value.decode("utf8")) + + def convert_list(list_value: t.List): + output = [] + + for value in list_value: + if isinstance(value, list): + # For nested arrays + output.append(convert_list(value)) + elif isinstance(value, str): + output.append(converter(value)) + else: + output.append(value) + + return output + + if isinstance(parsed, list): + return convert_list(parsed) + else: + return parsed + + +@decode_to_string +def convert_M2M_out(value: str) -> t.List: + return value.split(",") + + +############################################################################### +# Register the basic converters + +CONVERTERS = { + "NUMERIC": convert_numeric_out, + "INTEGER": convert_int_out, + "UUID": convert_uuid_out, + "DATE": convert_date_out, + "TIME": convert_time_out, + "SECONDS": convert_seconds_out, + "BOOLEAN": convert_boolean_out, + "TIMESTAMP": convert_timestamp_out, + "TIMESTAMPTZ": convert_timestamptz_out, + "M2M": convert_M2M_out, +} + +for column_name, converter in CONVERTERS.items(): + sqlite3.register_converter(column_name, converter) + +############################################################################### +# Register the array converters + +# The ARRAY column type handles values which can be easily serialised to and +# from JSON. +sqlite3.register_converter("ARRAY", convert_array_out) + +# We have special column types for arrays of timestamps etc, as simply loading +# the JSON isn't sufficient. +for column_name in ("TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME"): + sqlite3.register_converter( + f"ARRAY_{column_name}", + partial( + convert_complex_array_out, + converter=CONVERTERS[column_name], + ), + ) ############################################################################### diff --git a/tests/columns/test_array.py b/tests/columns/test_array.py index 6e325fbd4..4677ef995 100644 --- a/tests/columns/test_array.py +++ b/tests/columns/test_array.py @@ -1,6 +1,15 @@ +import datetime from unittest import TestCase -from piccolo.columns.column_types import Array, BigInt, Integer +from piccolo.columns.column_types import ( + Array, + BigInt, + Date, + Integer, + Time, + Timestamp, + Timestamptz, +) from piccolo.table import Table from tests.base import engines_only, sqlite_only @@ -22,7 +31,7 @@ def test_array_default(self): class TestArray(TestCase): """ - Make sure an Array column can be created, and work correctly. + Make sure an Array column can be created, and works correctly. """ def setUp(self): @@ -166,6 +175,84 @@ def test_cat_sqlite(self): ) +############################################################################### +# Date and time arrays + + +class DateTimeArrayTable(Table): + date = Array(Date()) + time = Array(Time()) + timestamp = Array(Timestamp()) + timestamptz = Array(Timestamptz()) + date_nullable = Array(Date(), null=True) + time_nullable = Array(Time(), null=True) + timestamp_nullable = Array(Timestamp(), null=True) + timestamptz_nullable = Array(Timestamptz(), null=True) + + +class TestDateTimeArray(TestCase): + """ + Make sure that data can be stored and retrieved when using arrays of + date / time / timestamp. + + We have to serialise / deserialise it in a special way in SQLite, hence + the tests. + + """ + + def setUp(self): + DateTimeArrayTable.create_table().run_sync() + + def tearDown(self): + DateTimeArrayTable.alter().drop_table().run_sync() + + @engines_only("postgres", "sqlite") + def test_storage(self): + test_date = datetime.date(year=2024, month=1, day=1) + test_time = datetime.time(hour=12, minute=0) + test_timestamp = datetime.datetime( + year=2024, month=1, day=1, hour=12, minute=0 + ) + test_timestamptz = datetime.datetime( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + tzinfo=datetime.timezone.utc, + ) + + DateTimeArrayTable( + { + DateTimeArrayTable.date: [test_date], + DateTimeArrayTable.time: [test_time], + DateTimeArrayTable.timestamp: [test_timestamp], + DateTimeArrayTable.timestamptz: [test_timestamptz], + DateTimeArrayTable.date_nullable: None, + DateTimeArrayTable.time_nullable: None, + DateTimeArrayTable.timestamp_nullable: None, + DateTimeArrayTable.timestamptz_nullable: None, + } + ).save().run_sync() + + row = DateTimeArrayTable.objects().first().run_sync() + assert row is not None + + self.assertListEqual(row.date, [test_date]) + self.assertListEqual(row.time, [test_time]) + self.assertListEqual(row.timestamp, [test_timestamp]) + self.assertListEqual(row.timestamptz, [test_timestamptz]) + + self.assertIsNone(row.date_nullable) + self.assertIsNone(row.time_nullable) + self.assertIsNone(row.timestamp_nullable) + self.assertIsNone(row.timestamptz_nullable) + + +############################################################################### +# Nested arrays + + class NestedArrayTable(Table): value = Array(base_column=Array(base_column=BigInt()))