diff --git a/src/types/sequence.rs b/src/types/sequence.rs index 62abb66fa6e..ce37b3d7259 100644 --- a/src/types/sequence.rs +++ b/src/types/sequence.rs @@ -1,3 +1,5 @@ +use std::iter::FusedIterator; + use crate::err::{self, DowncastError, PyDowncastError, PyErr, PyResult}; use crate::exceptions::PyTypeError; use crate::ffi_ptr_ext::FfiPtrExt; @@ -287,6 +289,9 @@ pub trait PySequenceMethods<'py>: crate::sealed::Sealed { /// Returns a fresh tuple based on the Sequence. fn to_tuple(&self) -> PyResult>; + + /// Returns an iterator over the Sequence's items. + fn iter(&self) -> BoundSequenceIterator<'py>; } impl<'py> PySequenceMethods<'py> for Bound<'py, PySequence> { @@ -462,6 +467,100 @@ impl<'py> PySequenceMethods<'py> for Bound<'py, PySequence> { .downcast_into_unchecked() } } + + #[inline] + fn iter(&self) -> BoundSequenceIterator<'py> { + BoundSequenceIterator::new(self.clone()) + } +} + +pub struct BoundSequenceIterator<'py> { + sequence: Bound<'py, PySequence>, + index: usize, + length: usize, +} + +impl<'py> BoundSequenceIterator<'py> { + fn new(sequence: Bound<'py, PySequence>) -> Self { + let length: usize = sequence.len().expect("failed to get sequence length"); + Self { + sequence, + index: 0, + length, + } + } + + unsafe fn get_item(&self, index: usize) -> PyResult> { + self.sequence.get_item(index) + } +} + +impl<'py> Iterator for BoundSequenceIterator<'py> { + type Item = PyResult>; + + #[inline] + fn next(&mut self) -> Option { + let length = self + .length + .min(self.sequence.len().expect("failed to get sequence length")); + + if self.index < length { + let item = unsafe { self.get_item(self.index) }; + self.index += 1; + Some(item) + } else { + None + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl DoubleEndedIterator for BoundSequenceIterator<'_> { + #[inline] + fn next_back(&mut self) -> Option { + let length = self + .length + .min(self.sequence.len().expect("failed to get sequence length")); + + if self.index < length { + let item = unsafe { self.get_item(length - 1) }; + self.length = length - 1; + Some(item) + } else { + None + } + } +} + +impl ExactSizeIterator for BoundSequenceIterator<'_> { + fn len(&self) -> usize { + self.length.saturating_sub(self.index) + } +} + +impl FusedIterator for BoundSequenceIterator<'_> {} + +impl<'py> IntoIterator for Bound<'py, PySequence> { + type Item = PyResult>; + type IntoIter = BoundSequenceIterator<'py>; + + fn into_iter(self) -> Self::IntoIter { + BoundSequenceIterator::new(self) + } +} + +impl<'py> IntoIterator for &Bound<'py, PySequence> { + type Item = PyResult>; + type IntoIter = BoundSequenceIterator<'py>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } } #[inline] @@ -509,7 +608,7 @@ where }; let mut v = Vec::with_capacity(seq.len().unwrap_or(0)); - for item in seq.iter()? { + for item in seq.iter() { v.push(item?.extract::()?); } Ok(v) @@ -900,6 +999,23 @@ mod tests { }); } + #[test] + fn test_seq_iter_bound() { + use crate::types::any::PyAnyMethods; + + Python::with_gil(|py| { + let v: Vec = vec![1, 1, 2, 3, 5, 8]; + let ob = v.to_object(py); + let seq = ob.downcast_bound::(py).unwrap(); + let mut idx = 0; + for el in seq { + assert_eq!(v[idx], el.unwrap().extract::().unwrap()); + idx += 1; + } + assert_eq!(idx, v.len()); + }); + } + #[test] fn test_seq_strings() { Python::with_gil(|py| {