From d082df7f1b53057e15c8cbbc7e662ec808c27722 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 10 Aug 2023 17:50:29 +0200 Subject: [PATCH] `SinglefileData`: Add `mode` keyword to `get_content` This allows a user to retrieve the content in bytes. Currently, a user is forced to use the more elaborate form: with singlefile.open(mode='rb') as handle: content = handle.read() or go directly through the repository interface which is a bit hidden and requires to redundantly specify the filename: content = singlefile.base.repository.get_object_content( singlefile.filename, mode='rb' ) these variants can now be simplified to: content = singlefile.get_content('rb') --- aiida/orm/nodes/data/singlefile.py | 7 ++++--- tests/orm/nodes/data/test_singlefile.py | 8 ++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/aiida/orm/nodes/data/singlefile.py b/aiida/orm/nodes/data/singlefile.py index c352f0acae..6a841083c9 100644 --- a/aiida/orm/nodes/data/singlefile.py +++ b/aiida/orm/nodes/data/singlefile.py @@ -92,12 +92,13 @@ def open(self, path: str | None = None, mode: t.Literal['r', 'rb'] = 'r') -> t.I with self.base.repository.open(path, mode=mode) as handle: yield handle - def get_content(self) -> str: + def get_content(self, mode: str = 'r') -> str | bytes: """Return the content of the single file stored for this data node. - :return: the content of the file as a string + :param mode: the mode with which to open the file handle (default: read mode) + :return: the content of the file as a string or bytes, depending on ``mode``. """ - with self.open(mode='r') as handle: # type: ignore[call-overload] + with self.open(mode=mode) as handle: # type: ignore[call-overload] return handle.read() def set_file(self, file: str | t.IO, filename: str | pathlib.Path | None = None) -> None: diff --git a/tests/orm/nodes/data/test_singlefile.py b/tests/orm/nodes/data/test_singlefile.py index 8bf08c4f27..f35a00b9b6 100644 --- a/tests/orm/nodes/data/test_singlefile.py +++ b/tests/orm/nodes/data/test_singlefile.py @@ -198,3 +198,11 @@ def test_from_string(): node = SinglefileData.from_string(content, filename).store() assert node.get_content() == content assert node.filename == filename + + +def test_get_content(): + """Test the :meth:`aiida.orm.nodes.data.singlefile.SinglefileData.get_content` method.""" + content = b'some\ncontent' + node = SinglefileData.from_string(content.decode('utf-8')).store() + assert node.get_content() == content.decode('utf-8') + assert node.get_content('rb') == content