Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1003 Support arrays of timestamp / timestamptz / date / time in SQLite #1004

Merged
merged 2 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2532,7 +2532,14 @@ def column_type(self):
if engine_type in ("postgres", "cockroach"):
return f"{self.base_column.column_type}[]"
elif engine_type == "sqlite":
return "ARRAY"
inner_column = self._get_inner_column()
return (
f"ARRAY_{inner_column.column_type}"
if isinstance(
inner_column, (Date, Timestamp, Timestamptz, Time)
)
else "ARRAY"
)
raise Exception("Unrecognized engine type")

def _setup_base_column(self, table_class: t.Type[Table]):
Expand Down Expand Up @@ -2564,6 +2571,23 @@ def _get_dimensions(self, start: int = 0) -> int:
else:
return start + 1

def _get_inner_column(self) -> Column:
    """
    A helper function to get the innermost ``Column`` for the array,
    unwrapping any nested ``Array`` columns. For example::

        >>> Array(Varchar())._get_inner_column()
        Varchar

        >>> Array(Array(Varchar()))._get_inner_column()
        Varchar

    """
    base_column = self.base_column

    if isinstance(base_column, Array):
        # Multidimensional array - keep unwrapping until we reach a column
        # which isn't itself an ``Array``.
        return base_column._get_inner_column()

    return base_column

def _get_inner_value_type(self) -> t.Type:
"""
A helper function to get the innermost value type for the array. For
Expand All @@ -2576,10 +2600,7 @@ def _get_inner_value_type(self) -> t.Type:
str

"""
if isinstance(self.base_column, Array):
return self.base_column._get_inner_value_type()
else:
return self.base_column.value_type
return self._get_inner_column().value_type

def __getitem__(self, value: int) -> Array:
"""
Expand Down
228 changes: 171 additions & 57 deletions piccolo/engine/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import uuid
from dataclasses import dataclass
from decimal import Decimal
from functools import partial, wraps

from piccolo.engine.base import Batch, Engine, validate_savepoint_name
from piccolo.engine.exceptions import TransactionError
Expand All @@ -35,14 +36,14 @@
# In


def convert_numeric_in(value: Decimal) -> float:
    """
    Convert any Decimal values into floats.
    """
    return float(value)


def convert_uuid_in(value) -> str:
def convert_uuid_in(value: uuid.UUID) -> str:
"""
Converts the UUID value being passed into sqlite.
"""
Expand All @@ -56,7 +57,7 @@ def convert_time_in(value: datetime.time) -> str:
return value.isoformat()


def convert_date_in(value: datetime.date):
def convert_date_in(value: datetime.date) -> str:
"""
Converts the date value being passed into sqlite.
"""
Expand All @@ -74,122 +75,235 @@ def convert_datetime_in(value: datetime.datetime) -> str:
return str(value)


def convert_timedelta_in(value: datetime.timedelta) -> float:
    """
    Converts the timedelta value being passed into sqlite.

    The value is represented as its total duration in seconds.
    """
    return value.total_seconds()


def convert_array_in(value: list) -> str:
    """
    Converts a list value into a JSON string.

    It handles nested lists, and types like datetime / time / date, which
    aren't usually JSON serialisable - these are converted first using the
    adapter found for their type in ``ADAPTERS``.

    :raises ValueError:
        If a datetime / time / date element has no adapter in ``ADAPTERS``,
        or if an element is of a type which isn't handled here.

    """

    def serialise(data: list) -> list:
        output = []

        for item in data:
            if isinstance(item, list):
                # Nested array - serialise each level recursively.
                output.append(serialise(item))
            elif isinstance(
                item, (datetime.datetime, datetime.time, datetime.date)
            ):
                # The adapter is looked up using the element's exact type.
                if adapter := ADAPTERS.get(type(item)):
                    output.append(adapter(item))
                else:
                    raise ValueError("The adapter wasn't found.")
            elif item is None or isinstance(item, (str, int, float)):
                # We can safely JSON serialise these. Lists never reach this
                # branch, as they're handled by the first check.
                output.append(item)
            else:
                raise ValueError("We can't currently serialise this value.")

        return output

    return dump_json(serialise(value))


###############################################################################

# Register adapters

# Maps each Python type to the function which converts it into a value that
# can be passed into SQLite. ``convert_array_in`` also uses this mapping to
# serialise date / time values found inside lists.
ADAPTERS: t.Dict[t.Type, t.Callable[[t.Any], t.Any]] = {
    Decimal: convert_numeric_in,
    uuid.UUID: convert_uuid_in,
    datetime.time: convert_time_in,
    datetime.date: convert_date_in,
    datetime.datetime: convert_datetime_in,
    datetime.timedelta: convert_timedelta_in,
    list: convert_array_in,
}

for value_type, adapter in ADAPTERS.items():
    sqlite3.register_adapter(value_type, adapter)

###############################################################################

# Out


def convert_numeric_out(value: bytes) -> Decimal:
def decode_to_string(
    converter: t.Callable[[str], t.Any]
) -> t.Callable[[t.Union[str, bytes]], t.Any]:
    """
    This means we can use our converters with string and bytes. They are
    passed bytes when used directly via SQLite, and are passed strings when
    used by the array converters.

    :param converter:
        A function which accepts a string.
    :returns:
        A wrapper around ``converter`` which decodes bytes as UTF-8 before
        calling it, and passes strings through unchanged.
    :raises ValueError:
        Raised by the wrapper if it's passed anything other than str or
        bytes.

    """

    @wraps(converter)
    def wrapper(value: t.Union[str, bytes]) -> t.Any:
        if isinstance(value, bytes):
            return converter(value.decode("utf8"))
        elif isinstance(value, str):
            return converter(value)
        else:
            raise ValueError("Unsupported type")

    return wrapper


@decode_to_string
def convert_numeric_out(value: str) -> Decimal:
    """
    Convert the value from a numeric column into a Decimal.
    """
    return Decimal(value)


@decode_to_string
def convert_int_out(value: str) -> int:
    """
    Make sure Integer values are actually of type int.
    """
    try:
        # Parse directly where possible - going via float first silently
        # loses precision for integers larger than 2 ** 53.
        return int(value)
    except ValueError:
        # Values such as "1.0" aren't valid int literals, so fall back to
        # parsing as a float and truncating.
        return int(float(value))


@decode_to_string
def convert_uuid_out(value: str) -> uuid.UUID:
    """
    If the value is a uuid, convert it to a UUID instance.
    """
    return uuid.UUID(value)


@decode_to_string
def convert_date_out(value: str) -> datetime.date:
    """
    If the value is a date, convert it from its ISO format string to a date
    instance.
    """
    return datetime.date.fromisoformat(value)


@decode_to_string
def convert_time_out(value: str) -> datetime.time:
    """
    If the value is a time, convert it from its ISO format string to a time
    instance.
    """
    return datetime.time.fromisoformat(value)


@decode_to_string
def convert_seconds_out(value: str) -> datetime.timedelta:
    """
    If the value is from a seconds column, convert it to a timedelta instance.
    """
    return datetime.timedelta(seconds=float(value))


@decode_to_string
def convert_boolean_out(value: str) -> bool:
    """
    If the value is from a boolean column, convert it to a bool value.

    Only the string ``"1"`` is treated as ``True`` - anything else is
    ``False``.
    """
    return value == "1"


@decode_to_string
def convert_timestamp_out(value: str) -> datetime.datetime:
    """
    If the value is from a timestamp column, convert it to a datetime value.
    """
    return datetime.datetime.fromisoformat(value)


@decode_to_string
def convert_timestamptz_out(value: str) -> datetime.datetime:
    """
    If the value is from a timestamptz column, convert it to a datetime value,
    with a timezone of UTC.
    """
    # NOTE(review): ``replace`` overwrites the tzinfo without shifting the
    # time, so this presumably relies on values always being stored as UTC -
    # confirm against the code which writes these values.
    return datetime.datetime.fromisoformat(value).replace(
        tzinfo=datetime.timezone.utc
    )


@decode_to_string
def convert_array_out(value: str) -> t.List:
    """
    If the value is from an array column, deserialise the string back into a
    list.
    """
    return load_json(value)


def convert_complex_array_out(
    value: t.Union[str, bytes], converter: t.Callable
):
    """
    This is used to handle arrays of things like timestamps, which we can't
    just load from JSON without doing additional work to convert the elements
    back into Python objects.

    :param value:
        The JSON encoded array. Bytes are decoded as UTF-8 before parsing,
        and strings are parsed as they are.
    :param converter:
        Called with each string element found in the array (at any level of
        nesting) - its return value replaces that element.
    :returns:
        The parsed value. If it's a list, the string elements are converted
        and all other elements (for example ``None``) are left unchanged.
        If the parsed JSON isn't a list, it's returned as is.

    """
    raw = value.decode("utf8") if isinstance(value, bytes) else value
    parsed = load_json(raw)

    def convert_list(list_value: t.List) -> t.List:
        output = []

        for element in list_value:
            if isinstance(element, list):
                # For nested arrays
                output.append(convert_list(element))
            elif isinstance(element, str):
                output.append(converter(element))
            else:
                output.append(element)

        return output

    if isinstance(parsed, list):
        return convert_list(parsed)
    else:
        return parsed


@decode_to_string
def convert_M2M_out(value: str) -> t.List:
    """
    If the value is from an M2M column, split the comma separated string into
    a list of strings.
    """
    return value.split(",")


###############################################################################
# Register the basic converters

# Maps the name each converter is registered under to the function which
# converts the raw value from SQLite back into a Python object.
CONVERTERS = {
    "NUMERIC": convert_numeric_out,
    "INTEGER": convert_int_out,
    "UUID": convert_uuid_out,
    "DATE": convert_date_out,
    "TIME": convert_time_out,
    "SECONDS": convert_seconds_out,
    "BOOLEAN": convert_boolean_out,
    "TIMESTAMP": convert_timestamp_out,
    "TIMESTAMPTZ": convert_timestamptz_out,
    "M2M": convert_M2M_out,
}

for column_name, converter in CONVERTERS.items():
    sqlite3.register_converter(column_name, converter)

###############################################################################
# Register the array converters

# The ARRAY column type handles values which can be easily serialised to and
# from JSON.
sqlite3.register_converter("ARRAY", convert_array_out)

# We have special column types for arrays of timestamps etc, as simply loading
# the JSON isn't sufficient. Each one is named ``ARRAY_<inner type>``, and
# reuses the converter for the inner type to convert the individual elements.
for column_name in ("TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME"):
    sqlite3.register_converter(
        f"ARRAY_{column_name}",
        # ``partial`` binds the element converter straight away, so each
        # registered function keeps its own converter.
        partial(
            convert_complex_array_out,
            converter=CONVERTERS[column_name],
        ),
    )

###############################################################################

Expand Down
Loading
Loading