Skip to content

Commit

Permalink
1056 Make refresh work when foreign key is set to null, and with `l…
Browse files Browse the repository at this point in the history
…oad_json` (#1060)

* make `refresh` work when fk is set to null, and with `load_json`

* add missing `return`

* fix tests

* mention in the docs how `refresh` is useful in unit tests

* update docs

* add `TestRefreshWithLoadJSON`

* improve refresh docs

* fix test
  • Loading branch information
dantownsend authored Jul 30, 2024
1 parent 9a4a919 commit 8c1563e
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 21 deletions.
14 changes: 14 additions & 0 deletions docs/src/piccolo/query_types/objects.rst
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,20 @@ It works with ``prefetch`` too:
>>> band.manager.name
"New value"
``refresh`` is very useful in unit tests:

.. code-block:: python
# If we have an instance:
band = await Band.objects().where(Band.name == "Pythonistas").first()
# Call an API endpoint which updates the object (e.g. with httpx):
await client.patch(f"/band/{band.id}/", json={"popularity": 5000})
# Make sure the instance was updated:
await band.refresh()
assert band.popularity == 5000
-------------------------------------------------------------------------------

Query clauses
Expand Down
4 changes: 1 addition & 3 deletions piccolo/query/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ async def _process_results(self, results) -> QueryResponseType:
if column._alias is not None:
json_column_names.append(column._alias)
elif len(column._meta.call_chain) > 0:
json_column_names.append(
column._meta.get_default_alias().replace("$", ".")
)
json_column_names.append(column._meta.get_default_alias())
else:
json_column_names.append(column._meta.name)

Expand Down
42 changes: 28 additions & 14 deletions piccolo/query/methods/refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing as t

from piccolo.utils.encoding import JSONDict
from piccolo.utils.sync import run_sync

if t.TYPE_CHECKING: # pragma: no cover
Expand All @@ -20,13 +21,17 @@ class Refresh:
:param columns:
Which columns to refresh - it not specified, then all columns are
refreshed.
:param load_json:
Whether to load ``JSON`` / ``JSONB`` columns as objects, instead of
just a string.
"""

def __init__(
self,
instance: Table,
columns: t.Optional[t.Sequence[Column]] = None,
load_json: bool = False,
):
self.instance = instance

Expand All @@ -42,6 +47,7 @@ def __init__(
)

self.columns = columns
self.load_json = load_json

@property
def _columns(self) -> t.Sequence[Column]:
Expand Down Expand Up @@ -94,6 +100,24 @@ def _get_columns(self, instance: Table, columns: t.Sequence[Column]):

return select_columns

def _update_instance(self, instance: Table, data_dict: t.Dict):
"""
Update the table instance. It is called recursively, if the instance
has child instances.
"""
for key, value in data_dict.items():
if isinstance(value, dict) and not isinstance(value, JSONDict):
# If the value is a dict, then it's a child instance.
if all(i is None for i in value.values()):
# If all values in the nested object are None, then we can
# safely assume that the object itself is null, as the
# primary key value must be null.
setattr(instance, key, None)
else:
self._update_instance(getattr(instance, key), value)
else:
setattr(instance, key, value)

async def run(
self, in_pool: bool = True, node: t.Optional[str] = None
) -> Table:
Expand Down Expand Up @@ -128,30 +152,20 @@ async def run(
instance=self.instance, columns=columns
)

updated_values = (
data_dict = (
await instance.__class__.select(*select_columns)
.where(pk_column == primary_key_value)
.output(nested=True, load_json=self.load_json)
.first()
.run(node=node, in_pool=in_pool)
)

if updated_values is None:
if data_dict is None:
raise ValueError(
"The object doesn't exist in the database any more."
)

for key, value in updated_values.items():
# For prefetched objects, make sure we update them correctly
object_to_update = instance
column_name = key

if "." in key:
path = key.split(".")
column_name = path.pop()
for i in path:
object_to_update = getattr(object_to_update, i)

setattr(object_to_update, column_name, value)
self._update_instance(instance=instance, data_dict=data_dict)

return instance

Expand Down
3 changes: 3 additions & 0 deletions piccolo/query/methods/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@ def output(self: Self, *, load_json: bool) -> Self: ...
def output(self: Self, *, load_json: bool, as_list: bool) -> SelectJSON: # type: ignore # noqa: E501
...

@t.overload
def output(self: Self, *, load_json: bool, nested: bool) -> Self: ...

@t.overload
def output(self: Self, *, nested: bool) -> Self: ...

Expand Down
10 changes: 8 additions & 2 deletions piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,9 @@ def remove(self) -> Delete:
)

def refresh(
self, columns: t.Optional[t.Sequence[Column]] = None
self,
columns: t.Optional[t.Sequence[Column]] = None,
load_json: bool = False,
) -> Refresh:
"""
Used to fetch the latest data for this instance from the database.
Expand All @@ -551,6 +553,10 @@ def refresh(
If you only want to refresh certain columns, specify them here.
Otherwise all columns are refreshed.
:param load_json:
Whether to load ``JSON`` / ``JSONB`` columns as objects, instead of
just a string.
Example usage::
# Get an instance from the database.
Expand All @@ -564,7 +570,7 @@ def refresh(
instance.refresh().run_sync()
"""
return Refresh(instance=self, columns=columns)
return Refresh(instance=self, columns=columns, load_json=load_json)

@t.overload
def get_related(
Expand Down
43 changes: 42 additions & 1 deletion piccolo/utils/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,46 @@ def dump_json(data: t.Any, pretty: bool = False) -> str:
return json.dumps(data, **params) # type: ignore


class JSONDict(dict):
"""
Once we have parsed a JSON string into a dictionary, we can't distinguish
it from other dictionaries.
Sometimes we might want to - for example::
>>> await Album.select(
... Album.all_columns(),
... Album.recording_studio.all_columns()
... ).output(
... nested=True,
... load_json=True
... )
[{
'id': 1,
'band': 1,
'name': 'Awesome album 1',
'recorded_at': {
'id': 1,
'facilities': {'restaurant': True, 'mixing_desk': True},
'name': 'Abbey Road'
},
'release_date': datetime.date(2021, 1, 1)
}]
Facilities could be mistaken for a table.
"""

...


def load_json(data: str) -> t.Any:
return orjson.loads(data) if ORJSON else json.loads(data) # type: ignore
response = (
orjson.loads(data) if ORJSON else json.loads(data) # type: ignore
)

if isinstance(response, dict):
return JSONDict(**response)

return response
7 changes: 7 additions & 0 deletions tests/columns/m2m/test_m2m.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from unittest import TestCase

from piccolo.utils.encoding import JSONDict
from tests.base import engines_skip

try:
Expand Down Expand Up @@ -376,6 +377,9 @@ def test_select_single(self):

if isinstance(column, UUID):
self.assertIn(type(returned_value), (uuid.UUID, asyncpgUUID))
elif isinstance(column, (JSON, JSONB)):
self.assertEqual(type(returned_value), JSONDict)
self.assertEqual(original_value, returned_value)
else:
self.assertEqual(
type(original_value),
Expand All @@ -401,6 +405,9 @@ def test_select_single(self):
if isinstance(column, UUID):
self.assertIn(type(returned_value), (uuid.UUID, asyncpgUUID))
self.assertEqual(str(original_value), str(returned_value))
elif isinstance(column, (JSON, JSONB)):
self.assertEqual(type(returned_value), JSONDict)
self.assertEqual(original_value, returned_value)
else:
self.assertEqual(
type(original_value),
Expand Down
72 changes: 71 additions & 1 deletion tests/table/test_refresh.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import typing as t

from tests.base import DBTestCase, TableTest
from tests.example_apps.music.tables import Band, Concert, Manager, Venue
from tests.example_apps.music.tables import (
Band,
Concert,
Manager,
RecordingStudio,
Venue,
)


class TestRefresh(DBTestCase):
Expand Down Expand Up @@ -215,6 +223,27 @@ def test_updated_foreign_key(self) -> None:
self.assertEqual(band.manager.id, new_manager.id)
self.assertEqual(band.manager.name, "New Manager")

def test_foreign_key_set_to_null(self):
"""
Make sure that if the foreign key was set to null, that ``refresh``
sets the nested object to ``None``.
"""
band = (
Band.objects(Band.manager)
.where(Band.name == "Pythonistas")
.first()
.run_sync()
)
assert band is not None

# Remove the manager from band
Band.update({Band.manager: None}, force=True).run_sync()

# Refresh `band`, and make sure the foreign key value is now `None`,
# instead of a nested object.
band.refresh().run_sync()
self.assertIsNone(band.manager)

def test_exception(self) -> None:
"""
We don't currently let the user refresh specific fields from nested
Expand All @@ -225,3 +254,44 @@ def test_exception(self) -> None:

# Shouldn't raise an exception:
self.concert.refresh(columns=[Concert.band_1]).run_sync()


class TestRefreshWithLoadJSON(TableTest):

tables = [RecordingStudio]

def setUp(self):
super().setUp()

self.recording_studio = RecordingStudio(
{RecordingStudio.facilities: {"piano": True}}
)
self.recording_studio.save().run_sync()

def test_load_json(self):
"""
Make sure we can refresh an object, and load the JSON as a Python
object.
"""
RecordingStudio.update(
{RecordingStudio.facilities: {"electric piano": True}},
force=True,
).run_sync()

# Refresh without load_json:
self.recording_studio.refresh().run_sync()

self.assertEqual(
# Remove the white space, because some versions of Python add
# whitespace around JSON, and some don't.
self.recording_studio.facilities.replace(" ", ""),
'{"electricpiano":true}',
)

# Refresh with load_json:
self.recording_studio.refresh(load_json=True).run_sync()

self.assertDictEqual(
t.cast(dict, self.recording_studio.facilities),
{"electric piano": True},
)

0 comments on commit 8c1563e

Please sign in to comment.