diff --git a/docs/src/piccolo/functions/string.rst b/docs/src/piccolo/functions/string.rst index dbc09125a..8d991c956 100644 --- a/docs/src/piccolo/functions/string.rst +++ b/docs/src/piccolo/functions/string.rst @@ -3,6 +3,11 @@ String functions .. currentmodule:: piccolo.query.functions.string +Concat +------ + +.. autoclass:: Concat + Length ------ diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 9c74b4a52..6140f6d13 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -86,47 +86,38 @@ class ConcatDelegate: def get_querystring( self, - column_name: str, - value: t.Union[str, Varchar, Text], + column: Column, + value: t.Union[str, Column, QueryString], reverse: bool = False, ) -> QueryString: - if isinstance(value, (Varchar, Text)): - column: Column = value + """ + :param reverse: + By default the value is appended to the column's value. If + ``reverse=True`` then the value is prepended to the column's + value instead. + + """ + if isinstance(value, Column): if len(column._meta.call_chain) > 0: raise ValueError( "Adding values across joins isn't currently supported." ) - other_column_name = column._meta.db_column_name - if reverse: - return QueryString( - Concat.template.format( - value_1=other_column_name, value_2=column_name - ) - ) - else: - return QueryString( - Concat.template.format( - value_1=column_name, value_2=other_column_name - ) - ) elif isinstance(value, str): - if reverse: - value_1 = QueryString("CAST({} AS text)", value) - return QueryString( - Concat.template.format(value_1="{}", value_2=column_name), - value_1, - ) - else: - value_2 = QueryString("CAST({} AS text)", value) - return QueryString( - Concat.template.format(value_1=column_name, value_2="{}"), - value_2, - ) - else: + value = QueryString("CAST({} AS TEXT)", value) + elif not isinstance(value, QueryString): raise ValueError( - "Only str, Varchar columns, and Text columns can be added." + "Only str, Column and QueryString values can be added." ) + args = [value, column] if reverse else [column, value] + + # We use the concat operator instead of the concat function, because + # this is what we historically used, and they treat null values + # differently. + return QueryString( + Concat.template.format(value_1="{}", value_2="{}"), *args + ) + class MathDelegate: """ @@ -340,12 +331,13 @@ def column_type(self): def __add__(self, value: t.Union[str, Varchar, Text]) -> QueryString: return self.concat_delegate.get_querystring( - column_name=self._meta.db_column_name, value=value + column=self, + value=value, ) def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString: return self.concat_delegate.get_querystring( - column_name=self._meta.db_column_name, + column=self, value=value, reverse=True, ) @@ -442,12 +434,13 @@ def __init__( def __add__(self, value: t.Union[str, Varchar, Text]) -> QueryString: return self.concat_delegate.get_querystring( - column_name=self._meta.db_column_name, value=value + column=self, + value=value, ) def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString: return self.concat_delegate.get_querystring( - column_name=self._meta.db_column_name, + column=self, value=value, reverse=True, ) diff --git a/piccolo/query/functions/__init__.py b/piccolo/query/functions/__init__.py index 8f8944d32..3163f6d1c 100644 --- a/piccolo/query/functions/__init__.py +++ b/piccolo/query/functions/__init__.py @@ -1,7 +1,7 @@ from .aggregate import Avg, Count, Max, Min, Sum from .datetime import Day, Extract, Hour, Month, Second, Strftime, Year from .math import Abs, Ceil, Floor, Round -from .string import Length, Lower, Ltrim, Reverse, Rtrim, Upper +from .string import Concat, Length, Lower, Ltrim, Reverse, Rtrim, Upper from .type_conversion import Cast __all__ = ( @@ -9,6 +9,7 @@ "Avg", "Cast", "Ceil", + "Concat", "Count", "Day", "Extract", diff --git a/piccolo/query/functions/string.py b/piccolo/query/functions/string.py index 556817a12..68b78219f 100644 --- a/piccolo/query/functions/string.py +++ b/piccolo/query/functions/string.py @@ -5,6 +5,12 @@ """ +import typing as t + +from piccolo.columns.base import Column +from piccolo.columns.column_types import Text, Varchar +from piccolo.querystring import QueryString + from .base import Function @@ -63,6 +69,44 @@ class Upper(Function): function_name = "UPPER" +class Concat(QueryString): + def __init__( + self, + *args: t.Union[Column, QueryString, str], + alias: t.Optional[str] = None, + ): + """ + Concatenate multiple values into a single string. + + .. note:: + Null values are ignored, so ``null + '!!!'`` returns ``!!!``, + not ``null``. + + .. warning:: + For SQLite, this is only available in version 3.44.0 and above. + + """ + if len(args) < 2: + raise ValueError("At least two values must be passed in.") + + placeholders = ", ".join("{}" for _ in args) + + processed_args: t.List[t.Union[QueryString, Column]] = [] + + for arg in args: + if isinstance(arg, str) or ( + isinstance(arg, Column) + and not isinstance(arg, (Varchar, Text)) + ): + processed_args.append(QueryString("CAST({} AS TEXT)", arg)) + else: + processed_args.append(arg) + + super().__init__( + f"CONCAT({placeholders})", *processed_args, alias=alias + ) + + __all__ = ( "Length", "Lower", @@ -70,4 +114,5 @@ class Upper(Function): "Reverse", "Rtrim", "Upper", + "Concat", ) diff --git a/tests/query/functions/test_string.py b/tests/query/functions/test_string.py index b87952634..bd3a8c2ab 100644 --- a/tests/query/functions/test_string.py +++ b/tests/query/functions/test_string.py @@ -1,10 +1,13 @@ -from piccolo.query.functions.string import Upper +import pytest + +from piccolo.query.functions.string import Concat, Upper +from tests.base import engine_version_lt, is_running_sqlite from tests.example_apps.music.tables import Band from .base import BandTest -class TestUpperFunction(BandTest): +class TestUpper(BandTest): def test_column(self): """ @@ -23,3 +26,32 @@ def test_joined_column(self): """ response = Band.select(Upper(Band.manager._.name)).run_sync() self.assertListEqual(response, [{"upper": "GUIDO"}]) + + +@pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.44), + reason="SQLite version not supported", +) +class TestConcat(BandTest): + + def test_column_and_string(self): + response = Band.select( + Concat(Band.name, "!!!", alias="name") + ).run_sync() + self.assertListEqual(response, [{"name": "Pythonistas!!!"}]) + + def test_column_and_column(self): + response = Band.select( + Concat(Band.name, Band.popularity, alias="name") + ).run_sync() + self.assertListEqual(response, [{"name": "Pythonistas1000"}]) + + def test_join(self): + response = Band.select( + Concat(Band.name, "-", Band.manager._.name, alias="name") + ).run_sync() + self.assertListEqual(response, [{"name": "Pythonistas-Guido"}]) + + def test_min_args(self): + with self.assertRaises(ValueError): + Concat()