Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1024 Add Concat function #1025

Merged
merged 8 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/src/piccolo/functions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ String functions

.. currentmodule:: piccolo.query.functions.string

Concat
------

.. autoclass:: Concat

Length
------

Expand Down
63 changes: 28 additions & 35 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,47 +86,38 @@ class ConcatDelegate:

def get_querystring(
self,
column_name: str,
value: t.Union[str, Varchar, Text],
column: Column,
value: t.Union[str, Column, QueryString],
reverse: bool = False,
) -> QueryString:
if isinstance(value, (Varchar, Text)):
column: Column = value
"""
:param reverse:
By default the value is appended to the column's value. If
``reverse=True`` then the value is prepended to the column's
value instead.

"""
if isinstance(value, Column):
if len(column._meta.call_chain) > 0:
raise ValueError(
"Adding values across joins isn't currently supported."
)
other_column_name = column._meta.db_column_name
if reverse:
return QueryString(
Concat.template.format(
value_1=other_column_name, value_2=column_name
)
)
else:
return QueryString(
Concat.template.format(
value_1=column_name, value_2=other_column_name
)
)
elif isinstance(value, str):
if reverse:
value_1 = QueryString("CAST({} AS text)", value)
return QueryString(
Concat.template.format(value_1="{}", value_2=column_name),
value_1,
)
else:
value_2 = QueryString("CAST({} AS text)", value)
return QueryString(
Concat.template.format(value_1=column_name, value_2="{}"),
value_2,
)
else:
value = QueryString("CAST({} AS TEXT)", value)
elif not isinstance(value, QueryString):
raise ValueError(
"Only str, Varchar columns, and Text columns can be added."
"Only str, Column and QueryString values can be added."
)

args = [value, column] if reverse else [column, value]

# We use the concat operator instead of the concat function, because
# this is what we historically used, and they treat null values
# differently.
return QueryString(
Concat.template.format(value_1="{}", value_2="{}"), *args
)


class MathDelegate:
"""
Expand Down Expand Up @@ -340,12 +331,13 @@ def column_type(self):

def __add__(self, value: t.Union[str, Varchar, Text]) -> QueryString:
return self.concat_delegate.get_querystring(
column_name=self._meta.db_column_name, value=value
column=self,
value=value,
)

def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString:
return self.concat_delegate.get_querystring(
column_name=self._meta.db_column_name,
column=self,
value=value,
reverse=True,
)
Expand Down Expand Up @@ -442,12 +434,13 @@ def __init__(

def __add__(self, value: t.Union[str, Varchar, Text]) -> QueryString:
return self.concat_delegate.get_querystring(
column_name=self._meta.db_column_name, value=value
column=self,
value=value,
)

def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString:
return self.concat_delegate.get_querystring(
column_name=self._meta.db_column_name,
column=self,
value=value,
reverse=True,
)
Expand Down
3 changes: 2 additions & 1 deletion piccolo/query/functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from .aggregate import Avg, Count, Max, Min, Sum
from .datetime import Day, Extract, Hour, Month, Second, Strftime, Year
from .math import Abs, Ceil, Floor, Round
from .string import Length, Lower, Ltrim, Reverse, Rtrim, Upper
from .string import Concat, Length, Lower, Ltrim, Reverse, Rtrim, Upper
from .type_conversion import Cast

__all__ = (
"Abs",
"Avg",
"Cast",
"Ceil",
"Concat",
"Count",
"Day",
"Extract",
Expand Down
45 changes: 45 additions & 0 deletions piccolo/query/functions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

"""

import typing as t

from piccolo.columns.base import Column
from piccolo.columns.column_types import Text, Varchar
from piccolo.querystring import QueryString

from .base import Function


Expand Down Expand Up @@ -63,11 +69,50 @@ class Upper(Function):
function_name = "UPPER"


class Concat(QueryString):
def __init__(
self,
*args: t.Union[Column, QueryString, str],
alias: t.Optional[str] = None,
):
"""
Concatenate multiple values into a single string.

.. note::
Null values are ignored, so ``null + '!!!'`` returns ``!!!``,
not ``null``.

.. warning::
For SQLite, this is only available in version 3.44.0 and above.

"""
if len(args) < 2:
raise ValueError("At least two values must be passed in.")

placeholders = ", ".join("{}" for _ in args)

processed_args: t.List[t.Union[QueryString, Column]] = []

for arg in args:
if isinstance(arg, str) or (
isinstance(arg, Column)
and not isinstance(arg, (Varchar, Text))
):
processed_args.append(QueryString("CAST({} AS TEXT)", arg))
else:
processed_args.append(arg)

super().__init__(
f"CONCAT({placeholders})", *processed_args, alias=alias
)


__all__ = (
"Length",
"Lower",
"Ltrim",
"Reverse",
"Rtrim",
"Upper",
"Concat",
)
36 changes: 34 additions & 2 deletions tests/query/functions/test_string.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from piccolo.query.functions.string import Upper
import pytest

from piccolo.query.functions.string import Concat, Upper
from tests.base import engine_version_lt, is_running_sqlite
from tests.example_apps.music.tables import Band

from .base import BandTest


class TestUpperFunction(BandTest):
class TestUpper(BandTest):

def test_column(self):
"""
Expand All @@ -23,3 +26,32 @@ def test_joined_column(self):
"""
response = Band.select(Upper(Band.manager._.name)).run_sync()
self.assertListEqual(response, [{"upper": "GUIDO"}])


@pytest.mark.skipif(
is_running_sqlite() and engine_version_lt(3.44),
reason="SQLite version not supported",
)
class TestConcat(BandTest):

def test_column_and_string(self):
response = Band.select(
Concat(Band.name, "!!!", alias="name")
).run_sync()
self.assertListEqual(response, [{"name": "Pythonistas!!!"}])

def test_column_and_column(self):
response = Band.select(
Concat(Band.name, Band.popularity, alias="name")
).run_sync()
self.assertListEqual(response, [{"name": "Pythonistas1000"}])

def test_join(self):
response = Band.select(
Concat(Band.name, "-", Band.manager._.name, alias="name")
).run_sync()
self.assertListEqual(response, [{"name": "Pythonistas-Guido"}])

def test_min_args(self):
with self.assertRaises(ValueError):
Concat()
Loading