Skip to content

Commit

Permalink
Revise Impl of nth and nth_back. Impl advance_by
Browse files Browse the repository at this point in the history
  • Loading branch information
Owen-CH-Leung committed Jan 9, 2025
1 parent 1b19616 commit f6e95a8
Showing 1 changed file with 187 additions and 1 deletion.
188 changes: 187 additions & 1 deletion src/types/list.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::iter::FusedIterator;

use crate::err::{self, PyResult};
use crate::ffi::{self, Py_ssize_t};
use crate::ffi_ptr_ext::FfiPtrExt;
Expand Down Expand Up @@ -547,6 +546,31 @@ impl<'py> BoundListIterator<'py> {
}
}

/// # Safety
///
/// On the free-threaded build, caller must verify they have exclusive
/// access to the list by holding a lock or by holding the innermost
/// critical section on the list.
#[inline]
#[cfg(not(Py_LIMITED_API))]
#[deny(unsafe_op_in_unsafe_fn)]
unsafe fn nth_unchecked(
index: &mut Index,
length: &mut Length,
list: &Bound<'py, PyList>,
n: usize,
) -> Option<Bound<'py, PyAny>> {
let length = length.0.min(list.len());
let target_index = index.0 + n;
if index.0 + n < length {
let item = unsafe { list.get_item_unchecked(target_index) };
index.0 = target_index + 1;
Some(item)
} else {
None
}
}

/// # Safety
///
/// On the free-threaded build, caller must verify they have exclusive
Expand Down Expand Up @@ -589,6 +613,31 @@ impl<'py> BoundListIterator<'py> {
}
}

/// # Safety
///
/// On the free-threaded build, caller must verify they have exclusive
/// access to the list by holding a lock or by holding the innermost
/// critical section on the list.
#[inline]
#[cfg(not(Py_LIMITED_API))]
#[deny(unsafe_op_in_unsafe_fn)]
unsafe fn nth_back_unchecked(
index: &mut Index,
length: &mut Length,
list: &Bound<'py, PyList>,
n: usize,
) -> Option<Bound<'py, PyAny>> {
let length_size = length.0.min(list.len());
if index.0 + n < length_size {
let target_index = length_size - n - 1;
let item = unsafe {list.get_item_unchecked(target_index)};
*length = Length(target_index);
Some(item)
} else {
None
}
}

#[cfg(not(Py_LIMITED_API))]
fn with_critical_section<R>(
&mut self,
Expand Down Expand Up @@ -625,6 +674,14 @@ impl<'py> Iterator for BoundListIterator<'py> {
}
}

#[inline]
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
fn nth(&mut self, n: usize) -> Option<Self::Item> {
self.with_critical_section(|index, length, list| unsafe {
Self::nth_unchecked(index, length, list, n)
})
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.len();
Expand Down Expand Up @@ -750,6 +807,27 @@ impl<'py> Iterator for BoundListIterator<'py> {
None
})
}

#[inline]
#[cfg(all(Py_GIL_DISABLED, feature = "nightly"))]
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
self.with_critical_section(|index, length, list| {
let length = length.0.min(list.len());
let target_index = index.0 + n;
if index.0 + n < length {
let item = list.get_item(target_index);
match item {
Ok(_) => {
index.0 = target_index;
Ok(())
}
Err(_) => Err(NonZero::new(n - index.0))
}
} else {
Err(NonZero::new(n - index.0))
}
})
}
}

impl DoubleEndedIterator for BoundListIterator<'_> {
Expand All @@ -772,6 +850,14 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
}
}

#[inline]
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
self.with_critical_section(|index, length, list| unsafe {
Self::nth_back_unchecked(index, length, list, n)
})
}

#[inline]
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
fn rfold<B, F>(mut self, init: B, mut f: F) -> B
Expand Down Expand Up @@ -1502,4 +1588,104 @@ mod tests {
assert!(tuple.eq(tuple_expected).unwrap());
})
}

#[test]
fn test_iter_nth() {
Python::with_gil(|py| {
let v = vec![6, 7, 8, 9, 10];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
iter.next();
assert_eq!(iter.nth(1).unwrap().extract::<i32>().unwrap(), 8);
assert_eq!(iter.nth(1).unwrap().extract::<i32>().unwrap(), 10);
assert!(iter.nth(1).is_none());

let v: Vec<i32> = vec![];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
iter.next();
assert!(iter.nth(1).is_none());

let v = vec![1, 2, 3];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert!(iter.nth(10).is_none());

let v = vec![6, 7, 8, 9, 10];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();
let mut iter = list.iter();
assert_eq!(iter.next().unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(iter.nth(2).unwrap().extract::<i32>().unwrap(), 9);
assert_eq!(iter.next().unwrap().extract::<i32>().unwrap(), 10);

let mut iter = list.iter();
iter.nth_back(1);
assert_eq!(iter.nth(2).unwrap().extract::<i32>().unwrap(), 8);
assert!(iter.next().is_none());
});
}

#[test]
fn test_iter_nth_back() {
Python::with_gil(|py| {
let v = vec![1, 2, 3, 4, 5];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert_eq!(iter.nth_back(0).unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 3);
assert!(iter.nth_back(2).is_none());

let v: Vec<i32> = vec![];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert!(iter.nth_back(0).is_none());
assert!(iter.nth_back(1).is_none());

let v = vec![1, 2, 3];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert!(iter.nth_back(5).is_none());

let v = vec![1, 2, 3, 4, 5];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
iter.next_back(); // Consume the last element
assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 3);
assert_eq!(iter.next_back().unwrap().extract::<i32>().unwrap(), 2);
assert_eq!(iter.nth_back(0).unwrap().extract::<i32>().unwrap(), 1);

let v = vec![1, 2, 3, 4, 5];
let ob = (&v).into_pyobject(py).unwrap();
let list = ob.downcast::<PyList>().unwrap();

let mut iter = list.iter();
assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 4);
assert_eq!(iter.nth_back(2).unwrap().extract::<i32>().unwrap(), 1);

let mut iter2 = list.iter();
iter2.next_back();
assert_eq!(iter2.nth_back(1).unwrap().extract::<i32>().unwrap(), 3);
assert_eq!(iter2.next_back().unwrap().extract::<i32>().unwrap(), 2);

let mut iter3 = list.iter();
iter3.nth(1);
assert_eq!(iter3.nth_back(2).unwrap().extract::<i32>().unwrap(), 3);
assert!(iter3.nth_back(0).is_none());
});
}
}

0 comments on commit f6e95a8

Please sign in to comment.