Skip to content

Commit

Permalink
Implement PR review suggestions.
Browse files Browse the repository at this point in the history
Signed-off-by: Joey Kleingers <[email protected]>
  • Loading branch information
joeykleingers authored and imikejackson committed Nov 22, 2024
1 parent 0dc07c4 commit 228034b
Showing 1 changed file with 39 additions and 33 deletions.
72 changes: 39 additions & 33 deletions src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ struct fmt::formatter<nx::core::Warning>
}
};

/**
* @brief Equivalent to lhs.__eq__(rhs) in python
* @param lhs
* @param rhs
* @return bool
*/
bool PyIsEqual(py::handle lhs, py::handle rhs)
{
return (lhs.attr("__eq__")(rhs)).cast<bool>();
}

template <class ParameterT>
void PyInsertLinkableParameter(Parameters& self, const ParameterT& param)
{
Expand Down Expand Up @@ -456,47 +467,47 @@ PYBIND11_MODULE(simplnx, mod)
mod.def(
"convert_np_dtype_to_datatype",
[](const py::dtype& dtype) {
if((dtype.attr("__eq__")(py::dtype::of<int8>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int8>()))
{
return DataType::int8;
}
if((dtype.attr("__eq__")(py::dtype::of<uint8>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint8>()))
{
return DataType::uint8;
}
if((dtype.attr("__eq__")(py::dtype::of<int16>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int16>()))
{
return DataType::int16;
}
if((dtype.attr("__eq__")(py::dtype::of<uint16>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint16>()))
{
return DataType::uint16;
}
if((dtype.attr("__eq__")(py::dtype::of<int32>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int32>()))
{
return DataType::int32;
}
if((dtype.attr("__eq__")(py::dtype::of<uint32>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint32>()))
{
return DataType::uint32;
}
if((dtype.attr("__eq__")(py::dtype::of<int64>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int64>()))
{
return DataType::int64;
}
if((dtype.attr("__eq__")(py::dtype::of<uint64>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint64>()))
{
return DataType::uint64;
}
if((dtype.attr("__eq__")(py::dtype::of<float32>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<float32>()))
{
return DataType::float32;
}
if((dtype.attr("__eq__")(py::dtype::of<float64>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<float64>()))
{
return DataType::float64;
}
if((dtype.attr("__eq__")(py::dtype::of<bool>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<bool>()))
{
return DataType::boolean;
}
Expand All @@ -507,50 +518,50 @@ PYBIND11_MODULE(simplnx, mod)
"Convert numpy dtype to simplnx DataType", "dtype"_a);

mod.def(
"convert_np_dtype_to_numerictype",
"convert_np_dtype_to_numeric_type",
[](const py::dtype& dtype) {
if((dtype.attr("__eq__")(py::dtype::of<int8>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int8>()))
{
return NumericType::int8;
}
if((dtype.attr("__eq__")(py::dtype::of<uint8>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint8>()))
{
return NumericType::uint8;
}
if((dtype.attr("__eq__")(py::dtype::of<int16>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int16>()))
{
return NumericType::int16;
}
if((dtype.attr("__eq__")(py::dtype::of<uint16>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint16>()))
{
return NumericType::uint16;
}
if((dtype.attr("__eq__")(py::dtype::of<int32>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int32>()))
{
return NumericType::int32;
}
if((dtype.attr("__eq__")(py::dtype::of<uint32>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint32>()))
{
return NumericType::uint32;
}
if((dtype.attr("__eq__")(py::dtype::of<int64>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<int64>()))
{
return NumericType::int64;
}
if((dtype.attr("__eq__")(py::dtype::of<uint64>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<uint64>()))
{
return NumericType::uint64;
}
if((dtype.attr("__eq__")(py::dtype::of<float32>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<float32>()))
{
return NumericType::float32;
}
if((dtype.attr("__eq__")(py::dtype::of<float64>())).cast<bool>())
if(PyIsEqual(dtype, py::dtype::of<float64>()))
{
return NumericType::float64;
}

std::string dtypeStr = py::str(static_cast<py::object>(dtype));
std::string dtypeStr = py::str(dtype);

Check failure on line 564 in src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp

View workflow job for this annotation

GitHub Actions / build (ubuntu-22.04, g++-11)

call of overloaded ‘str(const pybind11::dtype&)’ is ambiguous
throw std::invalid_argument(fmt::format("Unable to convert dtype to NumericType: Unsupported dtype '{}'.", dtypeStr));
},
"Convert numpy dtype to simplnx NumericType", "dtype"_a);
Expand Down Expand Up @@ -634,8 +645,7 @@ PYBIND11_MODULE(simplnx, mod)
parameters.def("insert_linkable_parameter", &PyInsertLinkableParameter<ChoicesParameter>);
parameters.def("link_parameters", [](Parameters& self, std::string groupKey, std::string childKey, BoolParameter::ValueType value) { self.linkParameters(groupKey, childKey, value); });
parameters.def("link_parameters", [](Parameters& self, std::string groupKey, std::string childKey, ChoicesParameter::ValueType value) { self.linkParameters(groupKey, childKey, value); });
parameters.def(
"__getitem__", [](Parameters& self, std::string_view key) { return self.at(key).get(); }, py::return_value_policy::reference_internal);
parameters.def("__getitem__", [](Parameters& self, std::string_view key) { return self.at(key).get(); }, py::return_value_policy::reference_internal);

Check failure on line 648 in src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp

View workflow job for this annotation

GitHub Actions / clang_format_pr

code should be clang-formatted [-Wclang-format-violations]

py::class_<IArrayThreshold, std::shared_ptr<IArrayThreshold>> iArrayThreshold(mod, "IArrayThreshold");

Expand Down Expand Up @@ -1473,12 +1483,10 @@ PYBIND11_MODULE(simplnx, mod)
"path"_a);
pipeline.def_property("name", &Pipeline::getName, &Pipeline::setName);
pipeline.def("execute", &ExecutePipeline);
pipeline.def(
"__getitem__", [](Pipeline& self, Pipeline::index_type index) { return self.at(index); }, py::return_value_policy::reference_internal);
pipeline.def("__getitem__", [](Pipeline& self, Pipeline::index_type index) { return self.at(index); }, py::return_value_policy::reference_internal);

Check failure on line 1486 in src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp

View workflow job for this annotation

GitHub Actions / clang_format_pr

code should be clang-formatted [-Wclang-format-violations]
pipeline.def("__len__", &Pipeline::size);
pipeline.def("size", &Pipeline::size);
pipeline.def(
"__iter__", [](Pipeline& self) { return py::make_iterator(self.begin(), self.end()); }, py::keep_alive<0, 1>());
pipeline.def("__iter__", [](Pipeline& self) { return py::make_iterator(self.begin(), self.end()); }, py::keep_alive<0, 1>());

Check failure on line 1489 in src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp

View workflow job for this annotation

GitHub Actions / clang_format_pr

code should be clang-formatted [-Wclang-format-violations]
pipeline.def(
"insert",
[internals](Pipeline& self, Pipeline::index_type index, const IFilter& filter, const py::dict& args) {
Expand All @@ -1492,10 +1500,8 @@ PYBIND11_MODULE(simplnx, mod)
pipeline.def("remove", &Pipeline::removeAt, "index"_a);

pipelineFilter.def("get_args", [internals](PipelineFilter& self) { return ConvertArgsToDict(*internals, self.getParameters(), self.getArguments()); });
pipelineFilter.def(
"set_args", [internals](PipelineFilter& self, py::dict& args) { self.setArguments(ConvertDictToArgs(*internals, self.getParameters(), args)); }, "args"_a);
pipelineFilter.def(
"get_filter", [](PipelineFilter& self) { return self.getFilter(); }, py::return_value_policy::reference_internal);
pipelineFilter.def("set_args", [internals](PipelineFilter& self, py::dict& args) { self.setArguments(ConvertDictToArgs(*internals, self.getParameters(), args)); }, "args"_a);

Check failure on line 1503 in src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp

View workflow job for this annotation

GitHub Actions / clang_format_pr

code should be clang-formatted [-Wclang-format-violations]
pipelineFilter.def("get_filter", [](PipelineFilter& self) { return self.getFilter(); }, py::return_value_policy::reference_internal);

Check failure on line 1504 in src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp

View workflow job for this annotation

GitHub Actions / clang_format_pr

code should be clang-formatted [-Wclang-format-violations]
pipelineFilter.def(
"name",
[](const PipelineFilter& self) {
Expand Down

0 comments on commit 228034b

Please sign in to comment.