From 4ae9fc720d23c0a02362c7d9a954e2ab82b04f5b Mon Sep 17 00:00:00 2001 From: Yu Ishihara Date: Wed, 22 Nov 2023 12:49:02 +0900 Subject: [PATCH] Avoid raising exception when info has inconsistent values --- nnabla_rl/utils/data.py | 13 +++++++++---- tests/utils/test_data.py | 22 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/nnabla_rl/utils/data.py b/nnabla_rl/utils/data.py index e8c01ecb..6a451af4 100644 --- a/nnabla_rl/utils/data.py +++ b/nnabla_rl/utils/data.py @@ -18,6 +18,7 @@ import numpy as np import nnabla as nn +from nnabla_rl.logger import logger from nnabla_rl.typing import TupledData T = TypeVar('T') @@ -77,10 +78,14 @@ def marshal_dict_experiences(dict_experiences: Sequence[Dict[str, Any]]) -> Dict dict_of_list = list_of_dict_to_dict_of_list(dict_experiences) marshaled_experiences = {} for key, data in dict_of_list.items(): - if isinstance(data[0], Dict): - marshaled_experiences.update({key: marshal_dict_experiences(data)}) - else: - marshaled_experiences.update({key: add_axis_if_single_dim(np.asarray(data))}) + try: + if isinstance(data[0], Dict): + marshaled_experiences.update({key: marshal_dict_experiences(data)}) + else: + marshaled_experiences.update({key: add_axis_if_single_dim(np.asarray(data))}) + except ValueError as e: + # do nothing + logger.warn(f'key: {key} contains inconsistent elements!. Details: {e}') return marshaled_experiences diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py index 9a47362a..820af35a 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -15,6 +15,7 @@ import numpy as np import pytest +from packaging.version import parse import nnabla as nn import nnabla_rl.environments as E @@ -133,6 +134,27 @@ def test_marshal_triple_nested_dict_experiences(self): np.testing.assert_allclose(np.asarray(key1_experiences), 1) np.testing.assert_allclose(np.asarray(key2_experiences), 2) + def test_marashal_dict_experiences_with_inhomogeneous_part(self): + installed_numpy_version = parse(np.__version__) + numpy_version1_24 = parse('1.24.0') + + if installed_numpy_version < numpy_version1_24: + # no need to test + return + + experiences = {'key1': 1, 'key2': 2} + inhomgeneous_experiences = {'key1': np.empty(shape=(6, )), 'key2': 2} + dict_experiences = [{'key_parent': experiences}, {'key_parent': inhomgeneous_experiences}] + + marshaled_experience = marshal_dict_experiences(dict_experiences) + + assert 'key1' not in marshaled_experience['key_parent'] + + key2_experiences = marshaled_experience['key_parent']['key2'] + assert key2_experiences.shape == (2, 1) + + np.testing.assert_allclose(np.asarray(key2_experiences), 2) + def test_list_of_dict_to_dict_of_list(self): list_of_dict = [{'key1': 1, 'key2': 2}, {'key1': 1, 'key2': 2}] dict_of_list = list_of_dict_to_dict_of_list(list_of_dict)