From 27941eeb1eaaec73a722b8c1d959aa01ec956368 Mon Sep 17 00:00:00 2001 From: Marcin Wojdyr Date: Fri, 4 Oct 2024 15:34:57 +0200 Subject: [PATCH] optional_caster::from_python: reset value on None (fixes #747) --- include/nanobind/stl/detail/nb_optional.h | 5 +++-- tests/test_stl.cpp | 4 ++++ tests/test_stl.py | 1 + tests/test_stl_ext.pyi.ref | 2 ++ 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/include/nanobind/stl/detail/nb_optional.h b/include/nanobind/stl/detail/nb_optional.h index 78365c97..deda2189 100644 --- a/include/nanobind/stl/detail/nb_optional.h +++ b/include/nanobind/stl/detail/nb_optional.h @@ -21,9 +21,10 @@ struct optional_caster { NB_TYPE_CASTER(Optional, optional_name(Caster::Name)) bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { - if (src.is_none()) - // default-constructed value is already empty + if (src.is_none()) { + value.reset(); return true; + } Caster caster; if (!caster.from_python(src, flags_for_local_caster(flags), cleanup) || diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index bd75db5c..834ccd46 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -464,6 +464,10 @@ NB_MODULE(test_stl_ext, m) { return x; }); + m.def("vector_optional_str", [](const std::vector>& x) { + return x; + }); + m.def("pass_wstr", [](std::wstring ws) { return ws; }); // uncomment to see compiler error: diff --git a/tests/test_stl.py b/tests/test_stl.py index bb2d4e1e..f5fc46f0 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -782,6 +782,7 @@ def test69_complex_array(): def test70_vec_char(): assert isinstance(t.vector_str("123"), str) assert isinstance(t.vector_str(["123", "345"]), list) + assert t.vector_optional_str(["abc", None]) == ["abc", None] def test71_null_input(): diff --git a/tests/test_stl_ext.pyi.ref b/tests/test_stl_ext.pyi.ref index 72a0f5fe..e62b4749 100644 --- a/tests/test_stl_ext.pyi.ref +++ b/tests/test_stl_ext.pyi.ref @@ -233,6 +233,8 @@ def vec_return_copyable() -> list[Copyable]: ... def vec_return_movable() -> list[Movable]: ... +def vector_optional_str(arg: Sequence[str | None], /) -> list[str | None]: ... + @overload def vector_str(arg: Sequence[str], /) -> list[str]: ...