From 228034ba4594dccc2ba43ed8db26d6d4969c5cbb Mon Sep 17 00:00:00 2001 From: Joey Kleingers Date: Wed, 20 Nov 2024 13:30:26 -0500 Subject: [PATCH] Implement PR review suggestions. Signed-off-by: Joey Kleingers --- .../SimplnxCore/wrapping/python/simplnxpy.cpp | 72 ++++++++++--------- 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp b/src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp index 75588f9d8c..59ecbf28f1 100644 --- a/src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp +++ b/src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp @@ -121,6 +121,17 @@ struct fmt::formatter } }; +/** + * @brief Equivalent to lhs.__eq__(rhs) in python + * @param lhs + * @param rhs + * @return bool + */ +bool PyIsEqual(py::handle lhs, py::handle rhs) +{ + return (lhs.attr("__eq__")(rhs)).cast(); +} + template void PyInsertLinkableParameter(Parameters& self, const ParameterT& param) { @@ -456,47 +467,47 @@ PYBIND11_MODULE(simplnx, mod) mod.def( "convert_np_dtype_to_datatype", [](const py::dtype& dtype) { - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::int8; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::uint8; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::int16; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::uint16; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::int32; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::uint32; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::int64; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::uint64; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::float32; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::float64; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return DataType::boolean; } @@ -507,50 +518,50 @@ PYBIND11_MODULE(simplnx, mod) "Convert numpy dtype to simplnx DataType", "dtype"_a); mod.def( - "convert_np_dtype_to_numerictype", + "convert_np_dtype_to_numeric_type", [](const py::dtype& dtype) { - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::int8; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::uint8; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::int16; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::uint16; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::int32; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::uint32; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::int64; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::uint64; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::float32; } - if((dtype.attr("__eq__")(py::dtype::of())).cast()) + if(PyIsEqual(dtype, py::dtype::of())) { return NumericType::float64; } - std::string dtypeStr = py::str(static_cast(dtype)); + std::string dtypeStr = py::str(dtype); throw std::invalid_argument(fmt::format("Unable to convert dtype to NumericType: Unsupported dtype '{}'.", dtypeStr)); }, "Convert numpy dtype to simplnx NumericType", "dtype"_a); @@ -634,8 +645,7 @@ PYBIND11_MODULE(simplnx, mod) parameters.def("insert_linkable_parameter", &PyInsertLinkableParameter); parameters.def("link_parameters", [](Parameters& self, std::string groupKey, std::string childKey, BoolParameter::ValueType value) { self.linkParameters(groupKey, childKey, value); }); parameters.def("link_parameters", [](Parameters& self, std::string groupKey, std::string childKey, ChoicesParameter::ValueType value) { self.linkParameters(groupKey, childKey, value); }); - parameters.def( - "__getitem__", [](Parameters& self, std::string_view key) { return self.at(key).get(); }, py::return_value_policy::reference_internal); + parameters.def("__getitem__", [](Parameters& self, std::string_view key) { return self.at(key).get(); }, py::return_value_policy::reference_internal); py::class_> iArrayThreshold(mod, "IArrayThreshold"); @@ -1473,12 +1483,10 @@ PYBIND11_MODULE(simplnx, mod) "path"_a); pipeline.def_property("name", &Pipeline::getName, &Pipeline::setName); pipeline.def("execute", &ExecutePipeline); - pipeline.def( - "__getitem__", [](Pipeline& self, Pipeline::index_type index) { return self.at(index); }, py::return_value_policy::reference_internal); + pipeline.def("__getitem__", [](Pipeline& self, Pipeline::index_type index) { return self.at(index); }, py::return_value_policy::reference_internal); pipeline.def("__len__", &Pipeline::size); pipeline.def("size", &Pipeline::size); - pipeline.def( - "__iter__", [](Pipeline& self) { return py::make_iterator(self.begin(), self.end()); }, py::keep_alive<0, 1>()); + pipeline.def("__iter__", [](Pipeline& self) { return py::make_iterator(self.begin(), self.end()); }, py::keep_alive<0, 1>()); pipeline.def( "insert", [internals](Pipeline& self, Pipeline::index_type index, const IFilter& filter, const py::dict& args) { @@ -1492,10 +1500,8 @@ PYBIND11_MODULE(simplnx, mod) pipeline.def("remove", &Pipeline::removeAt, "index"_a); pipelineFilter.def("get_args", [internals](PipelineFilter& self) { return ConvertArgsToDict(*internals, self.getParameters(), self.getArguments()); }); - pipelineFilter.def( - "set_args", [internals](PipelineFilter& self, py::dict& args) { self.setArguments(ConvertDictToArgs(*internals, self.getParameters(), args)); }, "args"_a); - pipelineFilter.def( - "get_filter", [](PipelineFilter& self) { return self.getFilter(); }, py::return_value_policy::reference_internal); + pipelineFilter.def("set_args", [internals](PipelineFilter& self, py::dict& args) { self.setArguments(ConvertDictToArgs(*internals, self.getParameters(), args)); }, "args"_a); + pipelineFilter.def("get_filter", [](PipelineFilter& self) { return self.getFilter(); }, py::return_value_policy::reference_internal); pipelineFilter.def( "name", [](const PipelineFilter& self) {