diff --git a/newsfragments/4810.added.md b/newsfragments/4810.added.md new file mode 100644 index 00000000000..e89e22e544d --- /dev/null +++ b/newsfragments/4810.added.md @@ -0,0 +1 @@ +Optimizes `nth` and `nth_back` for `BoundListIterator` \ No newline at end of file diff --git a/pyo3-benches/benches/bench_list.rs b/pyo3-benches/benches/bench_list.rs index cc790db37bf..7a19452455e 100644 --- a/pyo3-benches/benches/bench_list.rs +++ b/pyo3-benches/benches/bench_list.rs @@ -39,7 +39,33 @@ fn list_get_item(b: &mut Bencher<'_>) { }); } -#[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] +fn list_nth(b: &mut Bencher<'_>) { + Python::with_gil(|py| { + const LEN: usize = 50; + let list = PyList::new_bound(py, 0..LEN); + let mut sum = 0; + b.iter(|| { + for i in 0..LEN { + sum += list.iter().nth(i).unwrap().extract::().unwrap(); + } + }); + }); +} + +fn list_nth_back(b: &mut Bencher<'_>) { + Python::with_gil(|py| { + const LEN: usize = 50; + let list = PyList::new_bound(py, 0..LEN); + let mut sum = 0; + b.iter(|| { + for i in 0..LEN { + sum += list.iter().nth_back(i).unwrap().extract::().unwrap(); + } + }); + }); +} + +#[cfg(not(Py_LIMITED_API))] fn list_get_item_unchecked(b: &mut Bencher<'_>) { Python::with_gil(|py| { const LEN: usize = 50_000; @@ -66,6 +92,8 @@ fn sequence_from_list(b: &mut Bencher<'_>) { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("iter_list", iter_list); c.bench_function("list_new", list_new); + c.bench_function("list_nth", list_nth); + c.bench_function("list_nth_back", list_nth_back); c.bench_function("list_get_item", list_get_item); #[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] c.bench_function("list_get_item_unchecked", list_get_item_unchecked); diff --git a/src/lib.rs b/src/lib.rs index e5146c81c00..3547c52e57c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ #![warn(missing_docs)] #![cfg_attr( feature = "nightly", - feature(auto_traits, negative_impls, try_trait_v2) + feature(auto_traits, negative_impls, try_trait_v2, iter_advance_by) )] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] // Deny some lints in doctests. diff --git a/src/types/list.rs b/src/types/list.rs index 76da36d00b9..8391696f003 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -1,16 +1,16 @@ -use std::iter::FusedIterator; - use crate::err::{self, PyResult}; use crate::ffi::{self, Py_ssize_t}; use crate::ffi_ptr_ext::FfiPtrExt; use crate::internal_tricks::get_ssize_index; +use crate::types::any::PyAnyMethods; +use crate::types::sequence::PySequenceMethods; use crate::types::{PySequence, PyTuple}; use crate::{ Borrowed, Bound, BoundObject, IntoPyObject, IntoPyObjectExt, PyAny, PyErr, PyObject, Python, }; - -use crate::types::any::PyAnyMethods; -use crate::types::sequence::PySequenceMethods; +use std::iter::FusedIterator; +#[cfg(all(not(Py_LIMITED_API), feature = "nightly"))] +use std::num::NonZero; /// Represents a Python `list`. /// @@ -547,6 +547,46 @@ impl<'py> BoundListIterator<'py> { } } + #[inline] + #[cfg(all(not(Py_LIMITED_API), feature = "nightly"))] + #[deny(unsafe_op_in_unsafe_fn)] + unsafe fn nth_unchecked( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + n: usize, + ) -> Option> { + let length = length.0.min(list.len()); + let target_index = index.0 + n; + if index.0 + n < length { + let item = unsafe { list.get_item_unchecked(target_index) }; + index.0 = target_index + 1; + Some(item) + } else { + None + } + } + + #[inline] + #[cfg(all(Py_LIMITED_API, feature = "nightly"))] + #[deny(unsafe_op_in_unsafe_fn)] + fn nth( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + n: usize, + ) -> Option> { + let length = length.0.min(list.len()); + let target_index = index.0 + n; + if index.0 + n < length { + let item = list.get_item(target_index).expect("get-item failed"); + index.0 = target_index + 1; + Some(item) + } else { + None + } + } + /// # Safety /// /// On the free-threaded build, caller must verify they have exclusive @@ -589,6 +629,45 @@ impl<'py> BoundListIterator<'py> { } } + #[inline] + #[cfg(all(not(Py_LIMITED_API), feature = "nightly"))] + #[deny(unsafe_op_in_unsafe_fn)] + unsafe fn nth_back_unchecked( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + n: usize, + ) -> Option> { + let length_size = length.0.min(list.len()); + if index.0 + n < length_size { + let target_index = length_size - n - 1; + let item = unsafe { list.get_item_unchecked(target_index) }; + *length = Length(target_index); + Some(item) + } else { + None + } + } + + #[inline] + #[cfg(all(Py_LIMITED_API, feature = "nightly"))] + fn nth_back( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + n: usize, + ) -> Option> { + let length_size = length.0.min(list.len()); + if index.0 + n < length_size { + let target_index = length_size - n - 1; + let item = list.get_item(target_index).expect("get-item failed"); + *length = Length(target_index); + Some(item) + } else { + None + } + } + #[cfg(not(Py_LIMITED_API))] fn with_critical_section( &mut self, @@ -625,6 +704,26 @@ impl<'py> Iterator for BoundListIterator<'py> { } } + #[inline] + #[cfg(feature = "nightly")] + fn nth(&mut self, n: usize) -> Option { + #[cfg(not(Py_LIMITED_API))] + { + self.with_critical_section(|index, length, list| unsafe { + Self::nth_unchecked(index, length, list, n) + }) + } + #[cfg(Py_LIMITED_API)] + { + let Self { + index, + length, + list, + } = self; + Self::nth(index, length, list, n) + } + } + #[inline] fn size_hint(&self) -> (usize, Option) { let len = self.len(); @@ -750,6 +849,32 @@ impl<'py> Iterator for BoundListIterator<'py> { None }) } + + #[inline] + #[cfg(all(not(Py_LIMITED_API), feature = "nightly"))] + fn advance_by(&mut self, n: usize) -> Result<(), NonZero> { + self.with_critical_section(|index, length, list| { + let max_len = length.0.min(list.len()); + let currently_at = index.0; + if currently_at >= max_len { + if n == 0 { + return Ok(()); + } else { + return Err(unsafe { NonZero::new_unchecked(n) }); + } + } + + let items_left = max_len - currently_at; + if n <= items_left { + index.0 += n; + Ok(()) + } else { + index.0 = max_len; + let remainder = n - items_left; + Err(unsafe { NonZero::new_unchecked(remainder) }) + } + }) + } } impl DoubleEndedIterator for BoundListIterator<'_> { @@ -772,6 +897,26 @@ impl DoubleEndedIterator for BoundListIterator<'_> { } } + #[inline] + #[cfg(feature = "nightly")] + fn nth_back(&mut self, n: usize) -> Option { + #[cfg(not(Py_LIMITED_API))] + { + self.with_critical_section(|index, length, list| unsafe { + Self::nth_back_unchecked(index, length, list, n) + }) + } + #[cfg(Py_LIMITED_API)] + { + let Self { + index, + length, + list, + } = self; + Self::nth_back(index, length, list, n) + } + } + #[inline] #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] fn rfold(mut self, init: B, mut f: F) -> B @@ -839,6 +984,8 @@ mod tests { use crate::types::sequence::PySequenceMethods; use crate::types::{PyList, PyTuple}; use crate::{ffi, IntoPyObject, PyResult, Python}; + #[cfg(feature = "nightly")] + use std::num::NonZero; #[test] fn test_new() { @@ -1502,4 +1649,130 @@ mod tests { assert!(tuple.eq(tuple_expected).unwrap()); }) } + + #[test] + fn test_iter_nth() { + Python::with_gil(|py| { + let v = vec![6, 7, 8, 9, 10]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + iter.next(); + assert_eq!(iter.nth(1).unwrap().extract::().unwrap(), 8); + assert_eq!(iter.nth(1).unwrap().extract::().unwrap(), 10); + assert!(iter.nth(1).is_none()); + + let v: Vec = vec![]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + iter.next(); + assert!(iter.nth(1).is_none()); + + let v = vec![1, 2, 3]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth(10).is_none()); + + let v = vec![6, 7, 8, 9, 10]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + let mut iter = list.iter(); + assert_eq!(iter.next().unwrap().extract::().unwrap(), 6); + assert_eq!(iter.nth(2).unwrap().extract::().unwrap(), 9); + assert_eq!(iter.next().unwrap().extract::().unwrap(), 10); + + let mut iter = list.iter(); + iter.nth_back(1); + assert_eq!(iter.nth(2).unwrap().extract::().unwrap(), 8); + assert!(iter.next().is_none()); + }); + } + + #[test] + fn test_iter_nth_back() { + Python::with_gil(|py| { + let v = vec![1, 2, 3, 4, 5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.nth_back(0).unwrap().extract::().unwrap(), 5); + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 3); + assert!(iter.nth_back(2).is_none()); + + let v: Vec = vec![]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth_back(0).is_none()); + assert!(iter.nth_back(1).is_none()); + + let v = vec![1, 2, 3]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth_back(5).is_none()); + + let v = vec![1, 2, 3, 4, 5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + iter.next_back(); // Consume the last element + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 3); + assert_eq!(iter.next_back().unwrap().extract::().unwrap(), 2); + assert_eq!(iter.nth_back(0).unwrap().extract::().unwrap(), 1); + + let v = vec![1, 2, 3, 4, 5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 4); + assert_eq!(iter.nth_back(2).unwrap().extract::().unwrap(), 1); + + let mut iter2 = list.iter(); + iter2.next_back(); + assert_eq!(iter2.nth_back(1).unwrap().extract::().unwrap(), 3); + assert_eq!(iter2.next_back().unwrap().extract::().unwrap(), 2); + + let mut iter3 = list.iter(); + iter3.nth(1); + assert_eq!(iter3.nth_back(2).unwrap().extract::().unwrap(), 3); + assert!(iter3.nth_back(0).is_none()); + }); + } + + #[cfg(feature = "nightly")] + #[test] + fn test_iter_advance_by() { + Python::with_gil(|py| { + let v = vec![1, 2, 3, 4, 5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.advance_by(2), Ok(())); + assert_eq!(iter.next().unwrap().extract::().unwrap(), 3); + assert_eq!(iter.advance_by(0), Ok(())); + assert_eq!(iter.advance_by(100), Err(NonZero::new(98).unwrap())); + + let mut iter2 = list.iter(); + assert_eq!(iter2.advance_by(6), Err(NonZero::new(1).unwrap())); + + let mut iter3 = list.iter(); + assert_eq!(iter3.advance_by(5), Ok(())); + + let mut iter4 = list.iter(); + assert_eq!(iter4.advance_by(0), Ok(())); + assert_eq!(iter4.next().unwrap().extract::().unwrap(), 1); + }) + } }