Skip to content

Commit

Permalink
Perform two passes in the variant caster
Browse files Browse the repository at this point in the history
When converting a Python object that wraps a `T` to a
`std::variant<U,T>`, where `U` is implicitly convertible
from `T`, the variant caster will cast to the `U` (even though the
Python object is definitely a `T`)

The included test case demonstrates the issue. pybind11 does two pass
conversion in this case. One can work around the issue by using
`noconvert()` on the argument. But it seems that it would be a
friendlier default to make type conversions in a variant work as
expected.
  • Loading branch information
tjstum committed Nov 3, 2024
1 parent 017de5c commit b14f046
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ case, both modules must use the same nanobind ABI version, or they will be
isolated from each other. Releases that don't explicitly mention an ABI version
below inherit that of the preceding release.

Version TBD (unreleased)
------------------------

- The ``std::variant`` type_caster now does two passes when converting from Python.
The first pass is done without implicit conversions. This fixes an issue where
``std::variant<U, T>`` might cast a Python object wrapping a ``T`` to a ``U`` if
there is an implicit conversion available from ``T`` to ``U``.

Version 2.2.0 (October 3, 2024)
-------------------------------

Expand Down
5 changes: 5 additions & 0 deletions include/nanobind/stl/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ template <typename... Ts> struct type_caster<std::variant<Ts...>> {
}

bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (flags & (uint8_t) cast_flags::convert) {
if ((try_variant<Ts>(src, flags & ~(uint8_t)cast_flags::convert, cleanup) || ...)){
return true;
}
}
return (try_variant<Ts>(src, flags, cleanup) || ...);
}

Expand Down
31 changes: 31 additions & 0 deletions tests/test_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,4 +475,35 @@ NB_MODULE(test_stl_ext, m) {
m.def("optional_cstr", [](std::optional<const char*> arg) {
return arg.value_or("none");
}, nb::arg().none());


// test74
struct BasicID1 {
uint64_t id;
BasicID1(uint64_t id) : id(id) {}
};

struct BasicID2 {
uint64_t id;
BasicID2(uint64_t id) : id(id) {}
};

nb::class_<BasicID1>(m, "BasicID1")
.def(nb::init<uint64_t>())
.def("__int__", [](const BasicID1& x) { return x.id; })
;

nb::class_<BasicID2>(m, "BasicID2")
.def(nb::init_implicit<uint64_t>());

using IDVariants = std::variant<std::monostate, BasicID2, BasicID1>;

struct IDHavingEvent {
IDVariants id;
IDHavingEvent() = default;
};

nb::class_<IDHavingEvent>(m, "IDHavingEvent")
.def(nb::init<>())
.def_rw("id", &IDHavingEvent::id);
}
6 changes: 6 additions & 0 deletions tests/test_stl.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,3 +798,9 @@ def test72_wstr():
def test73_bad_input_to_set():
with pytest.raises(TypeError):
t.set_in_value(None)

def test74_variant_implicit_conversions():
event = t.IDHavingEvent()
assert event.id is None
event.id = t.BasicID1(78)
assert type(event.id) is t.BasicID1
17 changes: 17 additions & 0 deletions tests/test_stl_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ import pathlib
from typing import overload


class BasicID1:
def __init__(self, arg: int, /) -> None: ...

def __int__(self) -> int: ...

class BasicID2:
def __init__(self, arg: int, /) -> None: ...

class ClassWithMovableField:
def __init__(self) -> None: ...

Expand Down Expand Up @@ -38,6 +46,15 @@ class FuncWrapper:
alive: int = ...
"""static read-only property"""

class IDHavingEvent:
def __init__(self) -> None: ...

@property
def id(self) -> None | BasicID2 | BasicID1: ...

@id.setter
def id(self, arg: BasicID2 | BasicID1, /) -> None: ...

class Movable:
@overload
def __init__(self) -> None: ...
Expand Down

0 comments on commit b14f046

Please sign in to comment.