diff --git a/modules/core/src/tools/file/vpIoTools.cpp b/modules/core/src/tools/file/vpIoTools.cpp index c6321c1c39..42cbd30975 100644 --- a/modules/core/src/tools/file/vpIoTools.cpp +++ b/modules/core/src/tools/file/vpIoTools.cpp @@ -360,34 +360,35 @@ visp::cnpy::npz_t visp::cnpy::npz_load(std::string fname) if ((local_header[2] != 0x03) || (local_header[3] != 0x04)) { quit = true; } + else { + //read in the variable name + uint16_t name_len = *(uint16_t *)&local_header[26]; + std::string varname(name_len, ' '); + size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp); + if (vname_res != name_len) { + throw std::runtime_error("npz_load: failed fread"); + } - //read in the variable name - uint16_t name_len = *(uint16_t *)&local_header[26]; - std::string varname(name_len, ' '); - size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp); - if (vname_res != name_len) { - throw std::runtime_error("npz_load: failed fread"); - } - - //erase the lagging .npy - varname.erase(varname.end()-4, varname.end()); + //erase the lagging .npy + varname.erase(varname.end()-4, varname.end()); - //read in the extra field - uint16_t extra_field_len = *(uint16_t *)&local_header[28]; - if (extra_field_len > 0) { - std::vector buff(extra_field_len); - size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp); - if (efield_res != extra_field_len) { - throw std::runtime_error("npz_load: failed fread"); + //read in the extra field + uint16_t extra_field_len = *(uint16_t *)&local_header[28]; + if (extra_field_len > 0) { + std::vector buff(extra_field_len); + size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp); + if (efield_res != extra_field_len) { + throw std::runtime_error("npz_load: failed fread"); + } } - } - uint16_t compr_method = *reinterpret_cast(&local_header[0]+8); - uint32_t compr_bytes = *reinterpret_cast(&local_header[0]+18); - uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0]+22); + uint16_t compr_method = *reinterpret_cast(&local_header[0]+8); + uint32_t compr_bytes = *reinterpret_cast(&local_header[0]+18); + uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0]+22); - if (compr_method == 0) { arrays[varname] = load_the_npy_file(fp); } - else { arrays[varname] = load_the_npz_array(fp, compr_bytes, uncompr_bytes); } + if (compr_method == 0) { arrays[varname] = load_the_npy_file(fp); } + else { arrays[varname] = load_the_npz_array(fp, compr_bytes, uncompr_bytes); } + } } fclose(fp); @@ -423,33 +424,34 @@ visp::cnpy::NpyArray visp::cnpy::npz_load(std::string fname, std::string varname if ((local_header[2] != 0x03) || (local_header[3] != 0x04)) { quit = true; } + else { + //read in the variable name + uint16_t name_len = *(uint16_t *)&local_header[26]; + std::string vname(name_len, ' '); + size_t vname_res = fread(&vname[0], sizeof(char), name_len, fp); + if (vname_res != name_len) { + throw std::runtime_error("npz_load: failed fread"); + } + vname.erase(vname.end()-4, vname.end()); //erase the lagging .npy - //read in the variable name - uint16_t name_len = *(uint16_t *)&local_header[26]; - std::string vname(name_len, ' '); - size_t vname_res = fread(&vname[0], sizeof(char), name_len, fp); - if (vname_res != name_len) { - throw std::runtime_error("npz_load: failed fread"); - } - vname.erase(vname.end()-4, vname.end()); //erase the lagging .npy - - //read in the extra field - uint16_t extra_field_len = *(uint16_t *)&local_header[28]; - fseek(fp, extra_field_len, SEEK_CUR); //skip past the extra field + //read in the extra field + uint16_t extra_field_len = *(uint16_t *)&local_header[28]; + fseek(fp, extra_field_len, SEEK_CUR); //skip past the extra field - uint16_t compr_method = *reinterpret_cast(&local_header[0]+8); - uint32_t compr_bytes = *reinterpret_cast(&local_header[0]+18); - uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0]+22); + uint16_t compr_method = *reinterpret_cast(&local_header[0]+8); + uint32_t compr_bytes = *reinterpret_cast(&local_header[0]+18); + uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0]+22); - if (vname == varname) { - NpyArray array = (compr_method == 0) ? load_the_npy_file(fp) : load_the_npz_array(fp, compr_bytes, uncompr_bytes); - fclose(fp); - return array; - } - else { - //skip past the data - uint32_t size = *(uint32_t *)&local_header[22]; - fseek(fp, size, SEEK_CUR); + if (vname == varname) { + NpyArray array = (compr_method == 0) ? load_the_npy_file(fp) : load_the_npz_array(fp, compr_bytes, uncompr_bytes); + fclose(fp); + return array; + } + else { + //skip past the data + uint32_t size = *(uint32_t *)&local_header[22]; + fseek(fp, size, SEEK_CUR); + } } } @@ -2119,16 +2121,16 @@ std::string vpIoTools::toLowerCase(const std::string &input) out += std::tolower(*it); } return out; - } +} - /** - * @brief Return a upper-case version of the string \b input . - * Numbers and special characters stay the same - * - * @param input The input string for which we want to ensure that all the characters are in upper case. - * @return std::string A upper-case version of the string \b input, where - * numbers and special characters stay the same - */ +/** + * @brief Return a upper-case version of the string \b input . + * Numbers and special characters stay the same + * + * @param input The input string for which we want to ensure that all the characters are in upper case. + * @return std::string A upper-case version of the string \b input, where + * numbers and special characters stay the same + */ std::string vpIoTools::toUpperCase(const std::string &input) { std::string out; @@ -2140,16 +2142,16 @@ std::string vpIoTools::toUpperCase(const std::string &input) out += std::toupper(*it); } return out; - } +} - /*! - Returns the absolute path using realpath() on Unix systems or - GetFullPathName() on Windows systems. \return According to realpath() - manual, returns an absolute pathname that names the same file, whose - resolution does not involve '.', '..', or symbolic links for Unix systems. - According to GetFullPathName() documentation, retrieves the full path of the - specified file for Windows systems. - */ +/*! + Returns the absolute path using realpath() on Unix systems or + GetFullPathName() on Windows systems. \return According to realpath() + manual, returns an absolute pathname that names the same file, whose + resolution does not involve '.', '..', or symbolic links for Unix systems. + According to GetFullPathName() documentation, retrieves the full path of the + specified file for Windows systems. + */ std::string vpIoTools::getAbsolutePathname(const std::string &pathname) { diff --git a/modules/core/test/tools/io/testNPZ.cpp b/modules/core/test/tools/io/testNPZ.cpp new file mode 100644 index 0000000000..f384783487 --- /dev/null +++ b/modules/core/test/tools/io/testNPZ.cpp @@ -0,0 +1,161 @@ +/**************************************************************************** + * + * ViSP, open source Visual Servoing Platform software. + * Copyright (C) 2005 - 2024 by Inria. All rights reserved. + * + * This software is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * See the file LICENSE.txt at the root directory of this source + * distribution for additional information about the GNU GPL. + * + * For using ViSP with software that can not be combined with the GNU + * GPL, please contact Inria about acquiring a ViSP Professional + * Edition License. + * + * See https://visp.inria.fr for more information. + * + * This software was developed at: + * Inria Rennes - Bretagne Atlantique + * Campus Universitaire de Beaulieu + * 35042 Rennes Cedex + * France + * + * If you have questions regarding the use of this file, please contact + * Inria at visp@inria.fr + * + * This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING THE + * WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. + * + * Description: + * Test visp::cnpy::npz_load() / visp::cnpy::npy_save() functions. + * +*****************************************************************************/ + +#include +#include +#include + +#if defined(VISP_HAVE_CATCH2) && \ + (defined(_WIN32) || (defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)))) && \ + defined(VISP_LITTLE_ENDIAN) +#define CATCH_CONFIG_RUNNER +#include + +#include + +namespace +{ +std::string createTmpDir() +{ + std::string username; + vpIoTools::getUserName(username); + +#if defined(_WIN32) + std::string tmp_dir = "C:/temp/" + username; +#elif (defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__))) + std::string tmp_dir = "/tmp/" + username; +#endif + std::string directory_filename = tmp_dir + "/testNPZ/"; + + vpIoTools::makeDirectory(directory_filename); + return directory_filename; +} +} + +TEST_CASE("Test visp::cnpy::npy_load/npz_save", "[visp::cnpy I/O]") +{ + std::string directory_filename = createTmpDir(); + REQUIRE(vpIoTools::checkDirectory(directory_filename)); + std::string npz_filename = directory_filename + "/test_npz_read_write.npz"; + + SECTION("Read/Save string data") + { + const std::string save_string = "Open Source Visual Servoing Platform"; + std::vector vec_save_string(save_string.begin(), save_string.end()); + const std::string identifier = "String"; + visp::cnpy::npz_save(npz_filename, identifier, &vec_save_string[0], { vec_save_string.size() }, "w"); + + visp::cnpy::npz_t npz_data = visp::cnpy::npz_load(npz_filename); + visp::cnpy::NpyArray arr_string_data = npz_data[identifier]; + std::vector vec_arr_string_data = arr_string_data.as_vec(); + // For null-terminated character handling, see: + // https://stackoverflow.com/a/8247804 + // https://stackoverflow.com/a/45491652 + const std::string read_string = std::string(vec_arr_string_data.begin(), vec_arr_string_data.end()); + CHECK(save_string == read_string); + } + + SECTION("Read/Save multi-dimensional array") + { + size_t height = 5, width = 7, channels = 3; + std::vector save_vec; + save_vec.reserve(height*width*channels); + for (size_t i = 0; i < height*width*channels; i++) { + save_vec.push_back(i); + } + + const std::string identifier = "Array"; + visp::cnpy::npz_save(npz_filename, identifier, &save_vec[0], { height, width, channels }, "a"); // append + visp::cnpy::npz_t npz_data = visp::cnpy::npz_load(npz_filename); + visp::cnpy::NpyArray arr_vec_data = npz_data[identifier]; + std::vector read_vec = arr_vec_data.as_vec(); + + REQUIRE(save_vec.size() == read_vec.size()); + for (size_t i = 0; i < read_vec.size(); i++) { + CHECK(save_vec[i] == read_vec[i]); + } + } + + REQUIRE(vpIoTools::remove(directory_filename)); + REQUIRE(!vpIoTools::checkDirectory(directory_filename)); +} + +// https://en.cppreference.com/w/cpp/types/integer +// https://github.com/catchorg/Catch2/blob/devel/docs/test-cases-and-sections.md#type-parametrised-test-cases +using BasicTypes = std::tuple; +TEMPLATE_LIST_TEST_CASE("Test visp::cnpy::npy_load/npz_save", "[BasicTypes][list]", BasicTypes) +{ + std::string directory_filename = createTmpDir(); + REQUIRE(vpIoTools::checkDirectory(directory_filename)); + std::string npz_filename = directory_filename + "/test_npz_read_write.npz"; + + const std::string identifier = "data"; + TestType save_data = std::numeric_limits::min(); + visp::cnpy::npz_save(npz_filename, identifier, &save_data, { 1 }, "w"); + + visp::cnpy::npz_t npz_data = visp::cnpy::npz_load(npz_filename); + visp::cnpy::NpyArray arr_data = npz_data[identifier]; + TestType read_data = *arr_data.data(); + CHECK(save_data == read_data); + + save_data = std::numeric_limits::max(); + visp::cnpy::npz_save(npz_filename, identifier, &save_data, { 1 }, "a"); // append + + npz_data = visp::cnpy::npz_load(npz_filename); + arr_data = npz_data[identifier]; + read_data = *arr_data.data(); + CHECK(save_data == read_data); + + REQUIRE(vpIoTools::remove(directory_filename)); + REQUIRE(!vpIoTools::checkDirectory(directory_filename)); +} + +int main(int argc, char *argv[]) +{ + Catch::Session session; // There must be exactly one instance + + // Let Catch (using Clara) parse the command line + session.applyCommandLine(argc, argv); + + int numFailed = session.run(); + + // numFailed is clamped to 255 as some unices only use the lower 8 bits. + // This clamping has already been applied, so just return it here + // You can also do any post run clean-up here + return numFailed; +} +#else +int main() { return EXIT_SUCCESS; } +#endif