diff --git a/src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp b/src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp index 65ec0eb33d..d977dc8c12 100644 --- a/src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp +++ b/src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp @@ -121,6 +121,17 @@ struct fmt::formatter } }; +/** + * @brief Equivalent to lhs.__eq__(rhs) in python + * @param lhs + * @param rhs + * @return bool + */ +bool PyIsEqual(py::handle lhs, py::handle rhs) +{ + return (lhs.attr("__eq__")(rhs)).cast(); +} + template void PyInsertLinkableParameter(Parameters& self, const ParameterT& param) { @@ -456,47 +467,47 @@ PYBIND11_MODULE(simplnx, mod) mod.def( "convert_np_dtype_to_datatype", [](const py::dtype& dtype) { - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::int8; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::uint8; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::int16; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::uint16; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::int32; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::uint32; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::int64; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::uint64; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::float32; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::float64; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::boolean; } @@ -506,6 +517,55 @@ PYBIND11_MODULE(simplnx, mod) }, "Convert numpy dtype to simplnx DataType", "dtype"_a); + mod.def( + "convert_np_dtype_to_numeric_type", + [](const py::dtype& dtype) { + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::int8; + } + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::uint8; + } + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::int16; + } + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::uint16; + } + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::int32; + } + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::uint32; + } + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::int64; + } + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::uint64; + } + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::float32; + } + if(PyIsEqual(dtype, py::dtype::of())) + { + return NumericType::float64; + } + + std::string dtypeStr = py::str(static_cast(dtype)); + throw std::invalid_argument(fmt::format("Unable to convert dtype to NumericType: Unsupported dtype '{}'.", dtypeStr)); + }, + "Convert numpy dtype to simplnx NumericType", "dtype"_a); + py::enum_ arrayHandlingType(mod, "ArrayHandlingType"); arrayHandlingType.value("Copy", ArrayHandlingType::Copy); arrayHandlingType.value("Move", ArrayHandlingType::Move); @@ -1530,6 +1590,7 @@ PYBIND11_MODULE(simplnx, mod) mod.def("get_all_data_types", &GetAllDataTypes); mod.def("convert_numeric_type_to_data_type", &ConvertNumericTypeToDataType); + mod.def("convert_data_type_to_numeric_type", &ConvertDataTypeToNumericType); mod.def("get_filters", [corePlugin]() { auto filterHandles = corePlugin->getFilterHandles();