From b14f0460894c6b79fbd3b2ccc80b0182aef0ecf7 Mon Sep 17 00:00:00 2001 From: Tim Stumbaugh Date: Thu, 24 Oct 2024 14:18:41 -0600 Subject: [PATCH] Perform two passes in the variant caster When converting a Python object that wraps a `T` to a `std::variant`, where `U` is implicitly convertible from `T`, the variant caster will cast to the `U` (even though the Python object is definitely a `T`) The included test case demonstrates the issue. pybind11 does two pass conversion in this case. One can work around the issue by using `noconvert()` on the argument. But it seems that it would be a friendlier default to make type conversions in a variant work as expected. --- docs/changelog.rst | 8 ++++++++ include/nanobind/stl/variant.h | 5 +++++ tests/test_stl.cpp | 31 +++++++++++++++++++++++++++++++ tests/test_stl.py | 6 ++++++ tests/test_stl_ext.pyi.ref | 17 +++++++++++++++++ 5 files changed, 67 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5702a5d6..bde4da7e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -15,6 +15,14 @@ case, both modules must use the same nanobind ABI version, or they will be isolated from each other. Releases that don't explicitly mention an ABI version below inherit that of the preceding release. +Version TBD (unreleased) +------------------------ + +- The ``std::variant`` type_caster now does two passes when converting from Python. + The first pass is done without implicit conversions. This fixes an issue where + ``std::variant`` might cast a Python object wrapping a ``T`` to a ``U`` if + there is an implicit conversion available from ``T`` to ``U``. + Version 2.2.0 (October 3, 2024) ------------------------------- diff --git a/include/nanobind/stl/variant.h b/include/nanobind/stl/variant.h index d2804255..473ec8fa 100644 --- a/include/nanobind/stl/variant.h +++ b/include/nanobind/stl/variant.h @@ -45,6 +45,11 @@ template struct type_caster> { } bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (flags & (uint8_t) cast_flags::convert) { + if ((try_variant(src, flags & ~(uint8_t)cast_flags::convert, cleanup) || ...)){ + return true; + } + } return (try_variant(src, flags, cleanup) || ...); } diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index 834ccd46..688806f0 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -475,4 +475,35 @@ NB_MODULE(test_stl_ext, m) { m.def("optional_cstr", [](std::optional arg) { return arg.value_or("none"); }, nb::arg().none()); + + + // test74 + struct BasicID1 { + uint64_t id; + BasicID1(uint64_t id) : id(id) {} + }; + + struct BasicID2 { + uint64_t id; + BasicID2(uint64_t id) : id(id) {} + }; + + nb::class_(m, "BasicID1") + .def(nb::init()) + .def("__int__", [](const BasicID1& x) { return x.id; }) + ; + + nb::class_(m, "BasicID2") + .def(nb::init_implicit()); + + using IDVariants = std::variant; + + struct IDHavingEvent { + IDVariants id; + IDHavingEvent() = default; + }; + + nb::class_(m, "IDHavingEvent") + .def(nb::init<>()) + .def_rw("id", &IDHavingEvent::id); } diff --git a/tests/test_stl.py b/tests/test_stl.py index a4e584f8..cf423475 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -798,3 +798,9 @@ def test72_wstr(): def test73_bad_input_to_set(): with pytest.raises(TypeError): t.set_in_value(None) + +def test74_variant_implicit_conversions(): + event = t.IDHavingEvent() + assert event.id is None + event.id = t.BasicID1(78) + assert type(event.id) is t.BasicID1 diff --git a/tests/test_stl_ext.pyi.ref b/tests/test_stl_ext.pyi.ref index e62b4749..5f74b109 100644 --- a/tests/test_stl_ext.pyi.ref +++ b/tests/test_stl_ext.pyi.ref @@ -4,6 +4,14 @@ import pathlib from typing import overload +class BasicID1: + def __init__(self, arg: int, /) -> None: ... + + def __int__(self) -> int: ... + +class BasicID2: + def __init__(self, arg: int, /) -> None: ... + class ClassWithMovableField: def __init__(self) -> None: ... @@ -38,6 +46,15 @@ class FuncWrapper: alive: int = ... """static read-only property""" +class IDHavingEvent: + def __init__(self) -> None: ... + + @property + def id(self) -> None | BasicID2 | BasicID1: ... + + @id.setter + def id(self, arg: BasicID2 | BasicID1, /) -> None: ... + class Movable: @overload def __init__(self) -> None: ...