Skip to content

Commit

Permalink
PYTHON: Expose methods to convert numeric types. (#1133)
Browse files Browse the repository at this point in the history
- Expose method to convert numpy types to NumericType.
- Expose method to convert DataType to NumericType.

---------

Signed-off-by: Joey Kleingers <[email protected]>
  • Loading branch information
joeykleingers authored Nov 23, 2024
1 parent bae5244 commit b358403
Showing 1 changed file with 72 additions and 11 deletions.
83 changes: 72 additions & 11 deletions src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ struct fmt::formatter<nx::core::Warning>
}
};

/**
* @brief Equivalent to lhs.__eq__(rhs) in python
* @param lhs
* @param rhs
* @return bool
*/
bool PyIsEqual(py::handle lhs, py::handle rhs)
{
return (lhs.attr("__eq__")(rhs)).cast<bool>();
}

template <class ParameterT>
void PyInsertLinkableParameter(Parameters& self, const ParameterT& param)
{
Expand Down Expand Up @@ -456,47 +467,47 @@ PYBIND11_MODULE(simplnx, mod)
mod.def(
"convert_np_dtype_to_datatype",
[](const py::dtype& dtype) {
if((dtype.attr("__eq__")(py::dtype::of<int8>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int8>()))
{
return DataType::int8;
}
if((dtype.attr("__eq__")(py::dtype::of<uint8>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint8>()))
{
return DataType::uint8;
}
if((dtype.attr("__eq__")(py::dtype::of<int16>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int16>()))
{
return DataType::int16;
}
if((dtype.attr("__eq__")(py::dtype::of<uint16>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint16>()))
{
return DataType::uint16;
}
if((dtype.attr("__eq__")(py::dtype::of<int32>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int32>()))
{
return DataType::int32;
}
if((dtype.attr("__eq__")(py::dtype::of<uint32>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint32>()))
{
return DataType::uint32;
}
if((dtype.attr("__eq__")(py::dtype::of<int64>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int64>()))
{
return DataType::int64;
}
if((dtype.attr("__eq__")(py::dtype::of<uint64>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint64>()))
{
return DataType::uint64;
}
if((dtype.attr("__eq__")(py::dtype::of<float32>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<float32>()))
{
return DataType::float32;
}
if((dtype.attr("__eq__")(py::dtype::of<float64>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<float64>()))
{
return DataType::float64;
}
if((dtype.attr("__eq__")(py::dtype::of<bool>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<bool>()))
{
return DataType::boolean;
}
Expand All @@ -506,6 +517,55 @@ PYBIND11_MODULE(simplnx, mod)
},
"Convert numpy dtype to simplnx DataType", "dtype"_a);

mod.def(
"convert_np_dtype_to_numeric_type",
[](const py::dtype& dtype) {
if(PyIsEqual(dtype, py::dtype::of<int8>()))
{
return NumericType::int8;
}
if(PyIsEqual(dtype, py::dtype::of<uint8>()))
{
return NumericType::uint8;
}
if(PyIsEqual(dtype, py::dtype::of<int16>()))
{
return NumericType::int16;
}
if(PyIsEqual(dtype, py::dtype::of<uint16>()))
{
return NumericType::uint16;
}
if(PyIsEqual(dtype, py::dtype::of<int32>()))
{
return NumericType::int32;
}
if(PyIsEqual(dtype, py::dtype::of<uint32>()))
{
return NumericType::uint32;
}
if(PyIsEqual(dtype, py::dtype::of<int64>()))
{
return NumericType::int64;
}
if(PyIsEqual(dtype, py::dtype::of<uint64>()))
{
return NumericType::uint64;
}
if(PyIsEqual(dtype, py::dtype::of<float32>()))
{
return NumericType::float32;
}
if(PyIsEqual(dtype, py::dtype::of<float64>()))
{
return NumericType::float64;
}

std::string dtypeStr = py::str(static_cast<py::object>(dtype));
throw std::invalid_argument(fmt::format("Unable to convert dtype to NumericType: Unsupported dtype '{}'.", dtypeStr));
},
"Convert numpy dtype to simplnx NumericType", "dtype"_a);

py::enum_<ArrayHandlingType> arrayHandlingType(mod, "ArrayHandlingType");
arrayHandlingType.value("Copy", ArrayHandlingType::Copy);
arrayHandlingType.value("Move", ArrayHandlingType::Move);
Expand Down Expand Up @@ -1530,6 +1590,7 @@ PYBIND11_MODULE(simplnx, mod)
mod.def("get_all_data_types", &GetAllDataTypes);

mod.def("convert_numeric_type_to_data_type", &ConvertNumericTypeToDataType);
mod.def("convert_data_type_to_numeric_type", &ConvertDataTypeToNumericType);

mod.def("get_filters", [corePlugin]() {
auto filterHandles = corePlugin->getFilterHandles();
Expand Down

0 comments on commit b358403

Please sign in to comment.