From 99558f9618f4d319a9f3882a7784c9f3568042f4 Mon Sep 17 00:00:00 2001 From: Yu Ishihara Date: Thu, 14 Mar 2024 16:18:46 +0900 Subject: [PATCH] Add format specifier to file writer --- nnabla_rl/writers/file_writer.py | 6 ++--- .../writers/evaluation_results_scalar%.3f.tsv | 2 ++ .../writers/evaluation_results_scalar%.5f.tsv | 2 ++ .../writers/evaluation_results_scalar%f.tsv | 2 ++ tests/writers/test_file_writer.py | 25 +++++++++++++++++-- 5 files changed, 32 insertions(+), 5 deletions(-) create mode 100644 test_resources/writers/evaluation_results_scalar%.3f.tsv create mode 100644 test_resources/writers/evaluation_results_scalar%.5f.tsv create mode 100644 test_resources/writers/evaluation_results_scalar%f.tsv diff --git a/nnabla_rl/writers/file_writer.py b/nnabla_rl/writers/file_writer.py index 9e5689af..ef47c7ba 100644 --- a/nnabla_rl/writers/file_writer.py +++ b/nnabla_rl/writers/file_writer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,14 +22,14 @@ class FileWriter(Writer): - def __init__(self, outdir, file_prefix): + def __init__(self, outdir, file_prefix, fmt="%.3f"): super(FileWriter, self).__init__() if isinstance(outdir, str): outdir = pathlib.Path(outdir) self._outdir = outdir files.create_dir_if_not_exist(outdir=outdir) self._file_prefix = file_prefix - self._fmt = '%.3f' + self._fmt = fmt def write_scalar(self, iteration_num, scalar): outfile = self._outdir / (self._file_prefix + '_scalar.tsv') diff --git a/test_resources/writers/evaluation_results_scalar%.3f.tsv b/test_resources/writers/evaluation_results_scalar%.3f.tsv new file mode 100644 index 00000000..67e92db7 --- /dev/null +++ b/test_resources/writers/evaluation_results_scalar%.3f.tsv @@ -0,0 +1,2 @@ +iteration mean std_dev min max median +1 2.000 1.414 0.000 4.000 2.000 diff --git a/test_resources/writers/evaluation_results_scalar%.5f.tsv b/test_resources/writers/evaluation_results_scalar%.5f.tsv new file mode 100644 index 00000000..d40fa6d7 --- /dev/null +++ b/test_resources/writers/evaluation_results_scalar%.5f.tsv @@ -0,0 +1,2 @@ +iteration mean std_dev min max median +1 2.00000 1.41421 0.00000 4.00000 2.00000 diff --git a/test_resources/writers/evaluation_results_scalar%f.tsv b/test_resources/writers/evaluation_results_scalar%f.tsv new file mode 100644 index 00000000..f307cf9e --- /dev/null +++ b/test_resources/writers/evaluation_results_scalar%f.tsv @@ -0,0 +1,2 @@ +iteration mean std_dev min max median +1 2.000000 1.414214 0.000000 4.000000 2.000000 diff --git a/tests/writers/test_file_writer.py b/tests/writers/test_file_writer.py index 20ea64ec..eab5e44b 100644 --- a/tests/writers/test_file_writer.py +++ b/tests/writers/test_file_writer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ import tempfile import numpy as np +import pytest from nnabla_rl.writers.file_writer import FileWriter @@ -62,6 +63,27 @@ def test_write_histogram(self): os.path.join(test_file_dir, 'evaluation_results_histogram.tsv') self._check_same_tsv_file(file_path, test_file_path) + @pytest.mark.parametrize("format", ["%f", "%.3f", "%.5f"]) + def test_data_formatting(self, format): + with tempfile.TemporaryDirectory() as tmpdir: + test_returns = np.arange(5) + test_results = {} + test_results['mean'] = np.mean(test_returns) + test_results['std_dev'] = np.std(test_returns) + test_results['min'] = np.min(test_returns) + test_results['max'] = np.max(test_returns) + test_results['median'] = np.median(test_returns) + + writer = FileWriter(outdir=tmpdir, file_prefix='actual_results', fmt=format) + writer.write_scalar(1, test_results) + + actual_file_path = os.path.join(tmpdir, 'actual_results_scalar.tsv') + + this_file_dir = os.path.dirname(__file__) + expected_file_dir = this_file_dir.replace('tests', 'test_resources') + expected_file_path = os.path.join(expected_file_dir, f'evaluation_results_scalar{format}.tsv') + self._check_same_tsv_file(actual_file_path, expected_file_path) + def _check_same_tsv_file(self, file_path1, file_path2): # check each line with open(file_path1, mode='rt') as data_1, \ @@ -71,5 +93,4 @@ def _check_same_tsv_file(self, file_path1, file_path2): if __name__ == "__main__": - import pytest pytest.main()