Skip to content

Commit

Permalink
Perform two passes in the variant caster
Browse files Browse the repository at this point in the history
When converting a Python object that wraps a `T` to a
`std::variant<U,T>`, where `U` is implicitly convertible
from `T`, the variant caster will cast to the `U` (even though the
Python object is definitely a `T`)

The included test case demonstrates the issue. pybind11 does two pass
conversion in this case. One can work around the issue by using
`noconvert()` on the argument. But it seems that it would be a
friendlier default to make type conversions in a variant work as
expected.
  • Loading branch information
tjstum committed Oct 24, 2024
1 parent fd22b8c commit b246884
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ case, both modules must use the same nanobind ABI version, or they will be
isolated from each other. Releases that don't explicitly mention an ABI version
below inherit that of the preceding release.

Version TBD (unreleased)
------------------------

- The ``std::variant`` type_caster now does two passes when converting from Python.
The first pass is done without implicit conversions. This fixes an issue where
``std::variant<U, T>`` might cast a Python object wrapping a ``T`` to a ``U`` if
there is an implicit conversion available from ``T`` to ``U``.

Version 2.2.0 (October 3, 2024)
-------------------------------

Expand Down
5 changes: 5 additions & 0 deletions include/nanobind/stl/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ template <typename... Ts> struct type_caster<std::variant<Ts...>> {
}

bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (flags & (uint8_t) cast_flags::convert) {
if ((try_variant<Ts>(src, flags & ~(uint8_t)cast_flags::convert, cleanup) || ...)){
return true;
}
}
return (try_variant<Ts>(src, flags, cleanup) || ...);
}

Expand Down
31 changes: 31 additions & 0 deletions tests/test_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,4 +475,35 @@ NB_MODULE(test_stl_ext, m) {
m.def("optional_cstr", [](std::optional<const char*> arg) {
return arg.value_or("none");
}, nb::arg().none());


// test73
struct BasicID1 {
uint64_t id;
BasicID1(uint64_t id) : id(id) {}
};

struct BasicID2 {
uint64_t id;
BasicID2(uint64_t id) : id(id) {}
};

nb::class_<BasicID1>(m, "BasicID1")
.def(nb::init<uint64_t>())
.def("__int__", [](const BasicID1& x) { return x.id; })
;

nb::class_<BasicID2>(m, "BasicID2")
.def(nb::init_implicit<uint64_t>());

using IDVariants = std::variant<std::monostate, BasicID2, BasicID1>;

struct IDHavingEvent {
IDVariants id;
IDHavingEvent() = default;
};

nb::class_<IDHavingEvent>(m, "IDHavingEvent")
.def(nb::init<>())
.def_rw("id", &IDHavingEvent::id);
}
6 changes: 6 additions & 0 deletions tests/test_stl.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,3 +794,9 @@ def test71_null_input():
@skip_on_pypy # PyPy fails this test on Windows :-(
def test72_wstr():
assert t.pass_wstr('🎈') == '🎈'

def test73_variant_implicit_conversions():
event = t.IDHavingEvent()
assert event.id is None
event.id = t.BasicID1(78)
assert type(event.id) is t.BasicID1

0 comments on commit b246884

Please sign in to comment.