Skip to content

Commit

Permalink
1024 Add Concat function (#1025)
Browse files Browse the repository at this point in the history
* add `Concat` function

* update concat delegate

* go back to using the concat operator in concat delegate to keep behaviour consistent

* add tests

* fix cockroachdb test

* fix mypy error

* add notes that concat isn't available in older sqlite versions

* move decorator to right file
  • Loading branch information
dantownsend authored Jun 15, 2024
1 parent 13479c0 commit 98ad5d5
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 38 deletions.
5 changes: 5 additions & 0 deletions docs/src/piccolo/functions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ String functions

.. currentmodule:: piccolo.query.functions.string

Concat
------

.. autoclass:: Concat

Length
------

Expand Down
63 changes: 28 additions & 35 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,47 +86,38 @@ class ConcatDelegate:

def get_querystring(
self,
column_name: str,
value: t.Union[str, Varchar, Text],
column: Column,
value: t.Union[str, Column, QueryString],
reverse: bool = False,
) -> QueryString:
if isinstance(value, (Varchar, Text)):
column: Column = value
"""
:param reverse:
By default the value is appended to the column's value. If
``reverse=True`` then the value is prepended to the column's
value instead.
"""
if isinstance(value, Column):
if len(column._meta.call_chain) > 0:
raise ValueError(
"Adding values across joins isn't currently supported."
)
other_column_name = column._meta.db_column_name
if reverse:
return QueryString(
Concat.template.format(
value_1=other_column_name, value_2=column_name
)
)
else:
return QueryString(
Concat.template.format(
value_1=column_name, value_2=other_column_name
)
)
elif isinstance(value, str):
if reverse:
value_1 = QueryString("CAST({} AS text)", value)
return QueryString(
Concat.template.format(value_1="{}", value_2=column_name),
value_1,
)
else:
value_2 = QueryString("CAST({} AS text)", value)
return QueryString(
Concat.template.format(value_1=column_name, value_2="{}"),
value_2,
)
else:
value = QueryString("CAST({} AS TEXT)", value)
elif not isinstance(value, QueryString):
raise ValueError(
"Only str, Varchar columns, and Text columns can be added."
"Only str, Column and QueryString values can be added."
)

args = [value, column] if reverse else [column, value]

# We use the concat operator instead of the concat function, because
# this is what we historically used, and they treat null values
# differently.
return QueryString(
Concat.template.format(value_1="{}", value_2="{}"), *args
)


class MathDelegate:
"""
Expand Down Expand Up @@ -340,12 +331,13 @@ def column_type(self):

def __add__(self, value: t.Union[str, Varchar, Text]) -> QueryString:
return self.concat_delegate.get_querystring(
column_name=self._meta.db_column_name, value=value
column=self,
value=value,
)

def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString:
return self.concat_delegate.get_querystring(
column_name=self._meta.db_column_name,
column=self,
value=value,
reverse=True,
)
Expand Down Expand Up @@ -442,12 +434,13 @@ def __init__(

def __add__(self, value: t.Union[str, Varchar, Text]) -> QueryString:
return self.concat_delegate.get_querystring(
column_name=self._meta.db_column_name, value=value
column=self,
value=value,
)

def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString:
return self.concat_delegate.get_querystring(
column_name=self._meta.db_column_name,
column=self,
value=value,
reverse=True,
)
Expand Down
3 changes: 2 additions & 1 deletion piccolo/query/functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from .aggregate import Avg, Count, Max, Min, Sum
from .datetime import Day, Extract, Hour, Month, Second, Strftime, Year
from .math import Abs, Ceil, Floor, Round
from .string import Length, Lower, Ltrim, Reverse, Rtrim, Upper
from .string import Concat, Length, Lower, Ltrim, Reverse, Rtrim, Upper
from .type_conversion import Cast

__all__ = (
"Abs",
"Avg",
"Cast",
"Ceil",
"Concat",
"Count",
"Day",
"Extract",
Expand Down
45 changes: 45 additions & 0 deletions piccolo/query/functions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
"""

import typing as t

from piccolo.columns.base import Column
from piccolo.columns.column_types import Text, Varchar
from piccolo.querystring import QueryString

from .base import Function


Expand Down Expand Up @@ -63,11 +69,50 @@ class Upper(Function):
function_name = "UPPER"


class Concat(QueryString):
def __init__(
self,
*args: t.Union[Column, QueryString, str],
alias: t.Optional[str] = None,
):
"""
Concatenate multiple values into a single string.
.. note::
Null values are ignored, so ``null + '!!!'`` returns ``!!!``,
not ``null``.
.. warning::
For SQLite, this is only available in version 3.44.0 and above.
"""
if len(args) < 2:
raise ValueError("At least two values must be passed in.")

placeholders = ", ".join("{}" for _ in args)

processed_args: t.List[t.Union[QueryString, Column]] = []

for arg in args:
if isinstance(arg, str) or (
isinstance(arg, Column)
and not isinstance(arg, (Varchar, Text))
):
processed_args.append(QueryString("CAST({} AS TEXT)", arg))
else:
processed_args.append(arg)

super().__init__(
f"CONCAT({placeholders})", *processed_args, alias=alias
)


__all__ = (
"Length",
"Lower",
"Ltrim",
"Reverse",
"Rtrim",
"Upper",
"Concat",
)
36 changes: 34 additions & 2 deletions tests/query/functions/test_string.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from piccolo.query.functions.string import Upper
import pytest

from piccolo.query.functions.string import Concat, Upper
from tests.base import engine_version_lt, is_running_sqlite
from tests.example_apps.music.tables import Band

from .base import BandTest


class TestUpperFunction(BandTest):
class TestUpper(BandTest):

def test_column(self):
"""
Expand All @@ -23,3 +26,32 @@ def test_joined_column(self):
"""
response = Band.select(Upper(Band.manager._.name)).run_sync()
self.assertListEqual(response, [{"upper": "GUIDO"}])


@pytest.mark.skipif(
is_running_sqlite() and engine_version_lt(3.44),
reason="SQLite version not supported",
)
class TestConcat(BandTest):

def test_column_and_string(self):
response = Band.select(
Concat(Band.name, "!!!", alias="name")
).run_sync()
self.assertListEqual(response, [{"name": "Pythonistas!!!"}])

def test_column_and_column(self):
response = Band.select(
Concat(Band.name, Band.popularity, alias="name")
).run_sync()
self.assertListEqual(response, [{"name": "Pythonistas1000"}])

def test_join(self):
response = Band.select(
Concat(Band.name, "-", Band.manager._.name, alias="name")
).run_sync()
self.assertListEqual(response, [{"name": "Pythonistas-Guido"}])

def test_min_args(self):
with self.assertRaises(ValueError):
Concat()

0 comments on commit 98ad5d5

Please sign in to comment.