Skip to content

Commit

Permalink
fix enum flag from cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
skallweitNV committed Sep 26, 2024
1 parent efb22ed commit 815e30c
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 10 deletions.
17 changes: 16 additions & 1 deletion src/nb_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ PyObject *enum_create(enum_init_data *ed) noexcept {
scope.attr(name) = result;
result.attr("__doc__") = ed->docstr ? str(ed->docstr) : none();

result.attr("__str__") = enum_mod.attr("Enum").attr("__str__");
result.attr("__str__") = enum_mod.attr(is_flag ? factory_name : "Enum").attr("__str__");
result.attr("__repr__") = result.attr("__str__");

type_init_data *t = new type_init_data();
Expand Down Expand Up @@ -267,6 +267,21 @@ PyObject *enum_from_cpp(const std::type_info *tp, int64_t key) noexcept {
return value;
}

if ((t->flags & (uint32_t) enum_flags::is_flag) != 0) {
handle enum_tp(t->type_py);

object val;
if (t->flags & (uint32_t) enum_flags::is_signed)
val = steal(PyLong_FromLongLong((long long) key));
else
val = steal(PyLong_FromUnsignedLongLong((unsigned long long) key));

object result;
result = enum_tp.attr("__new__")(enum_tp, val);
Py_INCREF(result.ptr());
return result.ptr();
}

if (t->flags & (uint32_t) enum_flags::is_signed)
PyErr_Format(PyExc_ValueError, "%lli is not a valid %s.",
(long long) key, t->name);
Expand Down
4 changes: 2 additions & 2 deletions tests/test_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ NB_MODULE(test_enum_ext, m) {
m.def("from_enum", [](Enum value) { return (uint32_t) value; }, nb::arg().noconvert());
m.def("to_enum", [](uint32_t value) { return (Enum) value; });
m.def("from_enum", [](Flag value) { return (uint32_t) value; }, nb::arg().noconvert());
m.def("to_enum", [](uint32_t value) { return (Flag) value; });
m.def("to_flag", [](uint32_t value) { return (Flag) value; });
m.def("from_enum", [](SEnum value) { return (int32_t) value; }, nb::arg().noconvert());
m.def("to_enum", [](uint64_t value) { return (UnsignedFlag) value; });
m.def("to_unsigned_flag", [](uint64_t value) { return (UnsignedFlag) value; });
m.def("from_enum", [](UnsignedFlag value) { return (uint64_t) value; }, nb::arg().noconvert());
m.def("from_enum_implicit", [](Enum value) { return (uint32_t) value; });

Expand Down
5 changes: 5 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ def test06_enum_flag():
assert t.from_enum(t.UnsignedFlag.B) == 2
assert t.from_enum(t.UnsignedFlag.All) == 0xffffffffffffffff

assert t.to_flag(1) == t.Flag.A
assert t.to_flag(2) == t.Flag.B
assert t.to_flag(4) == t.Flag.C
assert t.to_flag(5) == (t.Flag.A | t.Flag.C)

def test09_enum_methods():
assert t.Item1.my_value == 0 and t.Item2.my_value == 1
assert t.Item1.get_value() == 0 and t.Item2.get_value() == 1
Expand Down
7 changes: 2 additions & 5 deletions tests/test_enum_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,8 @@ def from_enum_implicit(arg: Enum, /) -> int: ...
@overload
def from_enum_implicit(arg: Flag, /) -> int: ...

@overload
def to_enum(arg: int, /) -> Enum: ...

@overload
def to_enum(arg: int, /) -> Flag: ...
def to_flag(arg: int, /) -> Flag: ...

@overload
def to_enum(arg: int, /) -> UnsignedFlag: ...
def to_unsigned_flag(arg: int, /) -> UnsignedFlag: ...
2 changes: 0 additions & 2 deletions tests/test_ndarray_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ def ret_numpy_const_ref() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2

def ret_numpy_const_ref_f() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4), order='F', writable=False)]: ...

def ret_numpy_half() -> Annotated[ArrayLike, dict(dtype='float16', shape=(2, 4))]: ...

def ret_pytorch() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ...

def ret_tensorflow() -> tensorflow.python.framework.ops.EagerTensor[dtype=float32, shape=(2, 4)]: ...
Expand Down

0 comments on commit 815e30c

Please sign in to comment.