Skip to content

Commit

Permalink
Merge pull request #1396 from s-trinh/add_testNPZ
Browse files Browse the repository at this point in the history
Add basic I/O tests for visp::cnpy::npz_load(), visp::cnpy::npz_save() functions
  • Loading branch information
fspindle authored May 13, 2024
2 parents 6591281 + e46bd34 commit 5a34d72
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 65 deletions.
132 changes: 67 additions & 65 deletions modules/core/src/tools/file/vpIoTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,34 +360,35 @@ visp::cnpy::npz_t visp::cnpy::npz_load(std::string fname)
if ((local_header[2] != 0x03) || (local_header[3] != 0x04)) {
quit = true;
}
else {
//read in the variable name
uint16_t name_len = *(uint16_t *)&local_header[26];
std::string varname(name_len, ' ');
size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp);
if (vname_res != name_len) {
throw std::runtime_error("npz_load: failed fread");
}

//read in the variable name
uint16_t name_len = *(uint16_t *)&local_header[26];
std::string varname(name_len, ' ');
size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp);
if (vname_res != name_len) {
throw std::runtime_error("npz_load: failed fread");
}

//erase the lagging .npy
varname.erase(varname.end()-4, varname.end());
//erase the lagging .npy
varname.erase(varname.end()-4, varname.end());

//read in the extra field
uint16_t extra_field_len = *(uint16_t *)&local_header[28];
if (extra_field_len > 0) {
std::vector<char> buff(extra_field_len);
size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp);
if (efield_res != extra_field_len) {
throw std::runtime_error("npz_load: failed fread");
//read in the extra field
uint16_t extra_field_len = *(uint16_t *)&local_header[28];
if (extra_field_len > 0) {
std::vector<char> buff(extra_field_len);
size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp);
if (efield_res != extra_field_len) {
throw std::runtime_error("npz_load: failed fread");
}
}
}

uint16_t compr_method = *reinterpret_cast<uint16_t *>(&local_header[0]+8);
uint32_t compr_bytes = *reinterpret_cast<uint32_t *>(&local_header[0]+18);
uint32_t uncompr_bytes = *reinterpret_cast<uint32_t *>(&local_header[0]+22);
uint16_t compr_method = *reinterpret_cast<uint16_t *>(&local_header[0]+8);
uint32_t compr_bytes = *reinterpret_cast<uint32_t *>(&local_header[0]+18);
uint32_t uncompr_bytes = *reinterpret_cast<uint32_t *>(&local_header[0]+22);

if (compr_method == 0) { arrays[varname] = load_the_npy_file(fp); }
else { arrays[varname] = load_the_npz_array(fp, compr_bytes, uncompr_bytes); }
if (compr_method == 0) { arrays[varname] = load_the_npy_file(fp); }
else { arrays[varname] = load_the_npz_array(fp, compr_bytes, uncompr_bytes); }
}
}

fclose(fp);
Expand Down Expand Up @@ -423,33 +424,34 @@ visp::cnpy::NpyArray visp::cnpy::npz_load(std::string fname, std::string varname
if ((local_header[2] != 0x03) || (local_header[3] != 0x04)) {
quit = true;
}
else {
//read in the variable name
uint16_t name_len = *(uint16_t *)&local_header[26];
std::string vname(name_len, ' ');
size_t vname_res = fread(&vname[0], sizeof(char), name_len, fp);
if (vname_res != name_len) {
throw std::runtime_error("npz_load: failed fread");
}
vname.erase(vname.end()-4, vname.end()); //erase the lagging .npy

//read in the variable name
uint16_t name_len = *(uint16_t *)&local_header[26];
std::string vname(name_len, ' ');
size_t vname_res = fread(&vname[0], sizeof(char), name_len, fp);
if (vname_res != name_len) {
throw std::runtime_error("npz_load: failed fread");
}
vname.erase(vname.end()-4, vname.end()); //erase the lagging .npy

//read in the extra field
uint16_t extra_field_len = *(uint16_t *)&local_header[28];
fseek(fp, extra_field_len, SEEK_CUR); //skip past the extra field
//read in the extra field
uint16_t extra_field_len = *(uint16_t *)&local_header[28];
fseek(fp, extra_field_len, SEEK_CUR); //skip past the extra field

uint16_t compr_method = *reinterpret_cast<uint16_t *>(&local_header[0]+8);
uint32_t compr_bytes = *reinterpret_cast<uint32_t *>(&local_header[0]+18);
uint32_t uncompr_bytes = *reinterpret_cast<uint32_t *>(&local_header[0]+22);
uint16_t compr_method = *reinterpret_cast<uint16_t *>(&local_header[0]+8);
uint32_t compr_bytes = *reinterpret_cast<uint32_t *>(&local_header[0]+18);
uint32_t uncompr_bytes = *reinterpret_cast<uint32_t *>(&local_header[0]+22);

if (vname == varname) {
NpyArray array = (compr_method == 0) ? load_the_npy_file(fp) : load_the_npz_array(fp, compr_bytes, uncompr_bytes);
fclose(fp);
return array;
}
else {
//skip past the data
uint32_t size = *(uint32_t *)&local_header[22];
fseek(fp, size, SEEK_CUR);
if (vname == varname) {
NpyArray array = (compr_method == 0) ? load_the_npy_file(fp) : load_the_npz_array(fp, compr_bytes, uncompr_bytes);
fclose(fp);
return array;
}
else {
//skip past the data
uint32_t size = *(uint32_t *)&local_header[22];
fseek(fp, size, SEEK_CUR);
}
}
}

Expand Down Expand Up @@ -2119,16 +2121,16 @@ std::string vpIoTools::toLowerCase(const std::string &input)
out += std::tolower(*it);
}
return out;
}
}

/**
* @brief Return a upper-case version of the string \b input .
* Numbers and special characters stay the same
*
* @param input The input string for which we want to ensure that all the characters are in upper case.
* @return std::string A upper-case version of the string \b input, where
* numbers and special characters stay the same
*/
/**
* @brief Return a upper-case version of the string \b input .
* Numbers and special characters stay the same
*
* @param input The input string for which we want to ensure that all the characters are in upper case.
* @return std::string A upper-case version of the string \b input, where
* numbers and special characters stay the same
*/
std::string vpIoTools::toUpperCase(const std::string &input)
{
std::string out;
Expand All @@ -2140,16 +2142,16 @@ std::string vpIoTools::toUpperCase(const std::string &input)
out += std::toupper(*it);
}
return out;
}
}

/*!
Returns the absolute path using realpath() on Unix systems or
GetFullPathName() on Windows systems. \return According to realpath()
manual, returns an absolute pathname that names the same file, whose
resolution does not involve '.', '..', or symbolic links for Unix systems.
According to GetFullPathName() documentation, retrieves the full path of the
specified file for Windows systems.
*/
/*!
Returns the absolute path using realpath() on Unix systems or
GetFullPathName() on Windows systems. \return According to realpath()
manual, returns an absolute pathname that names the same file, whose
resolution does not involve '.', '..', or symbolic links for Unix systems.
According to GetFullPathName() documentation, retrieves the full path of the
specified file for Windows systems.
*/
std::string vpIoTools::getAbsolutePathname(const std::string &pathname)
{

Expand Down
161 changes: 161 additions & 0 deletions modules/core/test/tools/io/testNPZ.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/****************************************************************************
*
* ViSP, open source Visual Servoing Platform software.
* Copyright (C) 2005 - 2024 by Inria. All rights reserved.
*
* This software is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
* See the file LICENSE.txt at the root directory of this source
* distribution for additional information about the GNU GPL.
*
* For using ViSP with software that can not be combined with the GNU
* GPL, please contact Inria about acquiring a ViSP Professional
* Edition License.
*
* See https://visp.inria.fr for more information.
*
* This software was developed at:
* Inria Rennes - Bretagne Atlantique
* Campus Universitaire de Beaulieu
* 35042 Rennes Cedex
* France
*
* If you have questions regarding the use of this file, please contact
* Inria at [email protected]
*
* This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING THE
* WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.
*
* Description:
* Test visp::cnpy::npz_load() / visp::cnpy::npy_save() functions.
*
*****************************************************************************/

#include <iostream>
#include <visp3/core/vpConfig.h>
#include <visp3/core/vpEndian.h>

#if defined(VISP_HAVE_CATCH2) && \
(defined(_WIN32) || (defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)))) && \
defined(VISP_LITTLE_ENDIAN)
#define CATCH_CONFIG_RUNNER
#include <catch.hpp>

#include <visp3/core/vpIoTools.h>

namespace
{
std::string createTmpDir()
{
std::string username;
vpIoTools::getUserName(username);

#if defined(_WIN32)
std::string tmp_dir = "C:/temp/" + username;
#elif (defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)))
std::string tmp_dir = "/tmp/" + username;
#endif
std::string directory_filename = tmp_dir + "/testNPZ/";

vpIoTools::makeDirectory(directory_filename);
return directory_filename;
}
}

TEST_CASE("Test visp::cnpy::npy_load/npz_save", "[visp::cnpy I/O]")
{
std::string directory_filename = createTmpDir();
REQUIRE(vpIoTools::checkDirectory(directory_filename));
std::string npz_filename = directory_filename + "/test_npz_read_write.npz";

SECTION("Read/Save string data")
{
const std::string save_string = "Open Source Visual Servoing Platform";
std::vector<char> vec_save_string(save_string.begin(), save_string.end());
const std::string identifier = "String";
visp::cnpy::npz_save(npz_filename, identifier, &vec_save_string[0], { vec_save_string.size() }, "w");

visp::cnpy::npz_t npz_data = visp::cnpy::npz_load(npz_filename);
visp::cnpy::NpyArray arr_string_data = npz_data[identifier];
std::vector<char> vec_arr_string_data = arr_string_data.as_vec<char>();
// For null-terminated character handling, see:
// https://stackoverflow.com/a/8247804
// https://stackoverflow.com/a/45491652
const std::string read_string = std::string(vec_arr_string_data.begin(), vec_arr_string_data.end());
CHECK(save_string == read_string);
}

SECTION("Read/Save multi-dimensional array")
{
size_t height = 5, width = 7, channels = 3;
std::vector<int> save_vec;
save_vec.reserve(height*width*channels);
for (size_t i = 0; i < height*width*channels; i++) {
save_vec.push_back(i);
}

const std::string identifier = "Array";
visp::cnpy::npz_save(npz_filename, identifier, &save_vec[0], { height, width, channels }, "a"); // append
visp::cnpy::npz_t npz_data = visp::cnpy::npz_load(npz_filename);
visp::cnpy::NpyArray arr_vec_data = npz_data[identifier];
std::vector<int> read_vec = arr_vec_data.as_vec<int>();

REQUIRE(save_vec.size() == read_vec.size());
for (size_t i = 0; i < read_vec.size(); i++) {
CHECK(save_vec[i] == read_vec[i]);
}
}

REQUIRE(vpIoTools::remove(directory_filename));
REQUIRE(!vpIoTools::checkDirectory(directory_filename));
}

// https://en.cppreference.com/w/cpp/types/integer
// https://github.com/catchorg/Catch2/blob/devel/docs/test-cases-and-sections.md#type-parametrised-test-cases
using BasicTypes = std::tuple<uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t, float, double>;
TEMPLATE_LIST_TEST_CASE("Test visp::cnpy::npy_load/npz_save", "[BasicTypes][list]", BasicTypes)
{
std::string directory_filename = createTmpDir();
REQUIRE(vpIoTools::checkDirectory(directory_filename));
std::string npz_filename = directory_filename + "/test_npz_read_write.npz";

const std::string identifier = "data";
TestType save_data = std::numeric_limits<TestType>::min();
visp::cnpy::npz_save(npz_filename, identifier, &save_data, { 1 }, "w");

visp::cnpy::npz_t npz_data = visp::cnpy::npz_load(npz_filename);
visp::cnpy::NpyArray arr_data = npz_data[identifier];
TestType read_data = *arr_data.data<TestType>();
CHECK(save_data == read_data);

save_data = std::numeric_limits<TestType>::max();
visp::cnpy::npz_save(npz_filename, identifier, &save_data, { 1 }, "a"); // append

npz_data = visp::cnpy::npz_load(npz_filename);
arr_data = npz_data[identifier];
read_data = *arr_data.data<TestType>();
CHECK(save_data == read_data);

REQUIRE(vpIoTools::remove(directory_filename));
REQUIRE(!vpIoTools::checkDirectory(directory_filename));
}

int main(int argc, char *argv[])
{
Catch::Session session; // There must be exactly one instance

// Let Catch (using Clara) parse the command line
session.applyCommandLine(argc, argv);

int numFailed = session.run();

// numFailed is clamped to 255 as some unices only use the lower 8 bits.
// This clamping has already been applied, so just return it here
// You can also do any post run clean-up here
return numFailed;
}
#else
int main() { return EXIT_SUCCESS; }
#endif

0 comments on commit 5a34d72

Please sign in to comment.