Skip to content

Commit

Permalink
refactor, so we have an Arrow function
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Oct 20, 2024
1 parent a58663b commit 976f3ef
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 31 deletions.
34 changes: 5 additions & 29 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Band(Table):

if t.TYPE_CHECKING: # pragma: no cover
from piccolo.columns.base import ColumnMeta
from piccolo.query.functions.json import Arrow
from piccolo.table import Table


Expand Down Expand Up @@ -2335,33 +2336,6 @@ def __set__(self, obj, value: t.Union[str, t.Dict]):
obj.__dict__[self._meta.name] = value


class JSONQueryString(QueryString):
"""
Functionally this is basically the same as ``QueryString``, we just need
``Query._process_results`` to be able to differentiate it from a normal
``QueryString`` just incase the user specified ``.output(load_json=True)``.
"""

def clean_value(self, value: t.Any):
if not isinstance(value, (str, QueryString)):
value = dump_json(value)
return value

def __eq__(self, value) -> QueryString: # type: ignore[override]
value = self.clean_value(value)
return QueryString("{} = {}", self, value)

def __ne__(self, value) -> QueryString: # type: ignore[override]
value = self.clean_value(value)
return QueryString("{} != {}", self, value)

def eq(self, value) -> QueryString:
return self.__eq__(value)

def ne(self, value) -> QueryString:
return self.__ne__(value)


class JSONB(JSON):
"""
Used for storing JSON strings - Postgres only. The data is stored in a
Expand All @@ -2379,13 +2353,15 @@ class JSONB(JSON):
def column_type(self):
return "JSONB" # Must be defined, we override column_type() in JSON()

def arrow(self, key: str) -> JSONQueryString:
def arrow(self, key: str) -> Arrow:
"""
Allows part of the JSON structure to be returned - for example,
for {"a": 1}, and a key value of "a", then 1 will be returned.
"""
from piccolo.query.functions.json import Arrow

alias = self._alias or self._meta.get_default_alias()
return JSONQueryString("{} -> {}", self, key, alias=alias)
return Arrow(column=self, key=key, alias=alias)

###########################################################################
# Descriptors
Expand Down
5 changes: 3 additions & 2 deletions piccolo/query/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import typing as t
from time import time

from piccolo.columns.column_types import JSON, JSONB, JSONQueryString
from piccolo.columns.column_types import JSON, JSONB
from piccolo.custom_types import QueryResponseType, TableInstance
from piccolo.query.functions.json import Arrow
from piccolo.query.mixins import ColumnsDelegate
from piccolo.querystring import QueryString
from piccolo.utils.encoding import load_json
Expand Down Expand Up @@ -73,7 +74,7 @@ async def _process_results(self, results) -> QueryResponseType:
for column in columns_delegate.selected_columns:
if isinstance(column, (JSON, JSONB)):
json_columns.append(column)
elif isinstance(column, JSONQueryString):
elif isinstance(column, Arrow):
if alias := column._alias:
json_column_names.append(alias)
else:
Expand Down
40 changes: 40 additions & 0 deletions piccolo/query/functions/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

import typing as t

from piccolo.querystring import QueryString
from piccolo.utils.encoding import dump_json

if t.TYPE_CHECKING:
from piccolo.columns.column_types import JSONB


class Arrow(QueryString):
"""
Functionally this is basically the same as ``QueryString``, we just need
``Query._process_results`` to be able to differentiate it from a normal
``QueryString`` just in case the user specified
``.output(load_json=True)``.
"""

def __init__(self, column: JSONB, key: str, alias: t.Optional[str] = None):
super().__init__("{} -> {}", column, key, alias=alias)

def clean_value(self, value: t.Any):
if not isinstance(value, (str, QueryString)):
value = dump_json(value)
return value

def __eq__(self, value) -> QueryString: # type: ignore[override]
value = self.clean_value(value)
return QueryString("{} = {}", self, value)

def __ne__(self, value) -> QueryString: # type: ignore[override]
value = self.clean_value(value)
return QueryString("{} != {}", self, value)

def eq(self, value) -> QueryString:
return self.__eq__(value)

def ne(self, value) -> QueryString:
return self.__ne__(value)

0 comments on commit 976f3ef

Please sign in to comment.