From a748c05ea55c666509fb2b3e8c16a4b8f43b474c Mon Sep 17 00:00:00 2001 From: Simon Kallweit <64953474+skallweitNV@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:31:40 +0200 Subject: [PATCH] Fix enum flag from cpp conversion, improve stringification (#732) --- src/nb_enum.cpp | 18 ++++++++++++++++-- tests/test_enum.cpp | 4 ++-- tests/test_enum.py | 9 +++++++++ tests/test_enum_ext.pyi.ref | 7 ++----- tests/test_ndarray_ext.pyi.ref | 2 -- 5 files changed, 29 insertions(+), 11 deletions(-) diff --git a/src/nb_enum.cpp b/src/nb_enum.cpp index 425008c8..acde6f54 100644 --- a/src/nb_enum.cpp +++ b/src/nb_enum.cpp @@ -71,7 +71,7 @@ PyObject *enum_create(enum_init_data *ed) noexcept { scope.attr(name) = result; result.attr("__doc__") = ed->docstr ? str(ed->docstr) : none(); - result.attr("__str__") = enum_mod.attr("Enum").attr("__str__"); + result.attr("__str__") = enum_mod.attr(is_flag ? factory_name : "Enum").attr("__str__"); result.attr("__repr__") = result.attr("__str__"); type_init_data *t = new type_init_data(); @@ -267,12 +267,26 @@ PyObject *enum_from_cpp(const std::type_info *tp, int64_t key) noexcept { return value; } - if (t->flags & (uint32_t) enum_flags::is_signed) + uint32_t flags = t->flags; + if ((flags & (uint32_t) enum_flags::is_flag) != 0) { + handle enum_tp(t->type_py); + + object val; + if (flags & (uint32_t) enum_flags::is_signed) + val = steal(PyLong_FromLongLong((long long) key)); + else + val = steal(PyLong_FromUnsignedLongLong((unsigned long long) key)); + + return enum_tp.attr("__new__")(enum_tp, val).release().ptr(); + } + + if (flags & (uint32_t) enum_flags::is_signed) PyErr_Format(PyExc_ValueError, "%lli is not a valid %s.", (long long) key, t->name); else PyErr_Format(PyExc_ValueError, "%llu is not a valid %s.", (unsigned long long) key, t->name); + return nullptr; } diff --git a/tests/test_enum.cpp b/tests/test_enum.cpp index 5bef7fda..dfafb954 100644 --- a/tests/test_enum.cpp +++ b/tests/test_enum.cpp @@ -49,9 +49,9 @@ NB_MODULE(test_enum_ext, m) { m.def("from_enum", [](Enum value) { return (uint32_t) value; }, nb::arg().noconvert()); m.def("to_enum", [](uint32_t value) { return (Enum) value; }); m.def("from_enum", [](Flag value) { return (uint32_t) value; }, nb::arg().noconvert()); - m.def("to_enum", [](uint32_t value) { return (Flag) value; }); + m.def("to_flag", [](uint32_t value) { return (Flag) value; }); m.def("from_enum", [](SEnum value) { return (int32_t) value; }, nb::arg().noconvert()); - m.def("to_enum", [](uint64_t value) { return (UnsignedFlag) value; }); + m.def("to_unsigned_flag", [](uint64_t value) { return (UnsignedFlag) value; }); m.def("from_enum", [](UnsignedFlag value) { return (uint64_t) value; }, nb::arg().noconvert()); m.def("from_enum_implicit", [](Enum value) { return (uint32_t) value; }); diff --git a/tests/test_enum.py b/tests/test_enum.py index ac75e55f..e1736eaf 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -144,6 +144,10 @@ def test06_enum_flag(): assert str(t.Flag.B) == 'Flag.B' assert repr(t.Flag.C) == 'Flag.C' assert str(t.Flag.C) == 'Flag.C' + assert repr(t.Flag.A | t.Flag.B) in ['Flag.A|B', 'Flag.B|A'] + assert str(t.Flag.A | t.Flag.B) in ['Flag.A|B', 'Flag.B|A'] + assert repr(t.Flag.A | t.Flag.B | t.Flag.C) in ['Flag.A|B|C', 'Flag.C|B|A'] + assert str(t.Flag.A | t.Flag.B | t.Flag.C) in ['Flag.A|B|C', 'Flag.C|B|A'] # Flag membership tests assert (t.Flag(1) | t.Flag(2)).value == 3 @@ -170,6 +174,11 @@ def test06_enum_flag(): assert t.from_enum(t.UnsignedFlag.B) == 2 assert t.from_enum(t.UnsignedFlag.All) == 0xffffffffffffffff + assert t.to_flag(1) == t.Flag.A + assert t.to_flag(2) == t.Flag.B + assert t.to_flag(4) == t.Flag.C + assert t.to_flag(5) == (t.Flag.A | t.Flag.C) + def test09_enum_methods(): assert t.Item1.my_value == 0 and t.Item2.my_value == 1 assert t.Item1.get_value() == 0 and t.Item2.get_value() == 1 diff --git a/tests/test_enum_ext.pyi.ref b/tests/test_enum_ext.pyi.ref index dbe89957..37844391 100644 --- a/tests/test_enum_ext.pyi.ref +++ b/tests/test_enum_ext.pyi.ref @@ -100,11 +100,8 @@ def from_enum_implicit(arg: Enum, /) -> int: ... @overload def from_enum_implicit(arg: Flag, /) -> int: ... -@overload def to_enum(arg: int, /) -> Enum: ... -@overload -def to_enum(arg: int, /) -> Flag: ... +def to_flag(arg: int, /) -> Flag: ... -@overload -def to_enum(arg: int, /) -> UnsignedFlag: ... +def to_unsigned_flag(arg: int, /) -> UnsignedFlag: ... diff --git a/tests/test_ndarray_ext.pyi.ref b/tests/test_ndarray_ext.pyi.ref index ad2d87e5..b90384d0 100644 --- a/tests/test_ndarray_ext.pyi.ref +++ b/tests/test_ndarray_ext.pyi.ref @@ -170,8 +170,6 @@ def ret_numpy_const_ref() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2 def ret_numpy_const_ref_f() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4), order='F', writable=False)]: ... -def ret_numpy_half() -> Annotated[ArrayLike, dict(dtype='float16', shape=(2, 4))]: ... - def ret_pytorch() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... def ret_tensorflow() -> tensorflow.python.framework.ops.EagerTensor[dtype=float32, shape=(2, 4)]: ...