Skip to content

Commit

Permalink
simplify borrows by adding BoundListIterator::with_critical_section
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed Jan 6, 2025
1 parent f57bab1 commit 2967b21
Showing 1 changed file with 23 additions and 44 deletions.
67 changes: 23 additions & 44 deletions src/types/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,17 +495,6 @@ pub struct BoundListIterator<'py> {
length: Length,
}

/// Helper to manage mutable borrows below
macro_rules! split_borrow {
($instance:expr, $index:ident, $length:ident, $list:ident) => {
let Self {
ref mut $index,
ref mut $length,
ref $list,
} = $instance;
};
}

impl<'py> BoundListIterator<'py> {
fn new(list: Bound<'py, PyList>) -> Self {
Self {
Expand Down Expand Up @@ -597,18 +586,28 @@ impl<'py> BoundListIterator<'py> {
None
}
}

fn with_critical_section<R>(
&mut self,
f: impl FnOnce(&mut Index, &mut Length, &Bound<'py, PyList>) -> R,
) -> R {
let Self {
index,
length,
list,
} = self;
crate::sync::with_critical_section(list, || f(index, length, list))
}
}

impl<'py> Iterator for BoundListIterator<'py> {
type Item = Bound<'py, PyAny>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
split_borrow!(self, index, length, list);

#[cfg(Py_GIL_DISABLED)]
{
crate::sync::with_critical_section(list, || unsafe {
self.with_critical_section(|index, length, list| unsafe {
Self::next_unchecked(index, length, list)
})
}
Expand All @@ -635,9 +634,7 @@ impl<'py> Iterator for BoundListIterator<'py> {
Self: Sized,
F: FnMut(B, Self::Item) -> B,
{
split_borrow!(self, index, length, list);

crate::sync::with_critical_section(list, || {
self.with_critical_section(|index, length, list| {
let mut accum = init;
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
accum = f(accum, x);
Expand All @@ -654,9 +651,7 @@ impl<'py> Iterator for BoundListIterator<'py> {
F: FnMut(B, Self::Item) -> R,
R: std::ops::Try<Output = B>,
{
split_borrow!(self, index, length, list);

crate::sync::with_critical_section(list, || {
self.with_critical_section(|index, length, list| {
let mut accum = init;
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
accum = f(accum, x)?
Expand All @@ -672,9 +667,7 @@ impl<'py> Iterator for BoundListIterator<'py> {
Self: Sized,
F: FnMut(Self::Item) -> bool,
{
split_borrow!(self, index, length, list);

crate::sync::with_critical_section(list, || {
self.with_critical_section(|index, length, list| {
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
if !f(x) {
return false;
Expand All @@ -691,9 +684,7 @@ impl<'py> Iterator for BoundListIterator<'py> {
Self: Sized,
F: FnMut(Self::Item) -> bool,
{
split_borrow!(self, index, length, list);

crate::sync::with_critical_section(list, || {
self.with_critical_section(|index, length, list| {
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
if f(x) {
return true;
Expand All @@ -710,9 +701,7 @@ impl<'py> Iterator for BoundListIterator<'py> {
Self: Sized,
P: FnMut(&Self::Item) -> bool,
{
split_borrow!(self, index, length, list);

crate::sync::with_critical_section(list, || {
self.with_critical_section(|index, length, list| {
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
if predicate(&x) {
return Some(x);
Expand All @@ -729,9 +718,7 @@ impl<'py> Iterator for BoundListIterator<'py> {
Self: Sized,
F: FnMut(Self::Item) -> Option<B>,
{
split_borrow!(self, index, length, list);

crate::sync::with_critical_section(list, || {
self.with_critical_section(|index, length, list| {
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
if let found @ Some(_) = f(x) {
return found;
Expand All @@ -748,9 +735,7 @@ impl<'py> Iterator for BoundListIterator<'py> {
Self: Sized,
P: FnMut(Self::Item) -> bool,
{
split_borrow!(self, index, length, list);

crate::sync::with_critical_section(list, || {
self.with_critical_section(|index, length, list| {
let mut acc = 0;
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
if predicate(x) {
Expand All @@ -766,11 +751,9 @@ impl<'py> Iterator for BoundListIterator<'py> {
impl DoubleEndedIterator for BoundListIterator<'_> {
#[inline]
fn next_back(&mut self) -> Option<Self::Item> {
split_borrow!(self, index, length, list);

#[cfg(Py_GIL_DISABLED)]
{
crate::sync::with_critical_section(list, || unsafe {
self.with_critical_section(|index, length, list| unsafe {
Self::next_back_unchecked(index, length, list)
})
}
Expand All @@ -791,9 +774,7 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
Self: Sized,
F: FnMut(B, Self::Item) -> B,
{
split_borrow!(self, index, length, list);

crate::sync::with_critical_section(list, || {
self.with_critical_section(|index, length, list| {
let mut accum = init;
while let Some(x) = unsafe { Self::next_back_unchecked(index, length, list) } {
accum = f(accum, x);
Expand All @@ -810,9 +791,7 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
F: FnMut(B, Self::Item) -> R,
R: std::ops::Try<Output = B>,
{
split_borrow!(self, index, length, list);

crate::sync::with_critical_section(list, || {
self.with_critical_section(|index, length, list| {
let mut accum = init;
while let Some(x) = unsafe { Self::next_back_unchecked(index, length, list) } {
accum = f(accum, x)?
Expand Down

0 comments on commit 2967b21

Please sign in to comment.