Rollup merge of #126879 - the8472:next-chunk-filter-drop, r=cuviper

fix Drop items getting leaked in Filter::next_chunk

The optimization only makes sense for non-drop elements anyway. Use the default implementation for items that are Drop instead.

It also simplifies the implementation.

fixes #126872
tracking issue #98326
This commit is contained in:
Matthias Krüger 2024-06-26 07:50:18 +02:00 committed by GitHub
commit cf22be186c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 45 deletions

View File

@ -3,7 +3,7 @@
use crate::num::NonZero;
use crate::ops::Try;
use core::array;
use core::mem::{ManuallyDrop, MaybeUninit};
use core::mem::MaybeUninit;
use core::ops::ControlFlow;
/// An iterator that filters the elements of `iter` with `predicate`.
@ -27,6 +27,42 @@ pub(in crate::iter) fn new(iter: I, predicate: P) -> Filter<I, P> {
}
}
impl<I, P> Filter<I, P>
where
I: Iterator,
P: FnMut(&I::Item) -> bool,
{
#[inline]
fn next_chunk_dropless<const N: usize>(
&mut self,
) -> Result<[I::Item; N], array::IntoIter<I::Item, N>> {
let mut array: [MaybeUninit<I::Item>; N] = [const { MaybeUninit::uninit() }; N];
let mut initialized = 0;
let result = self.iter.try_for_each(|element| {
let idx = initialized;
// branchless index update combined with unconditionally copying the value even when
// it is filtered reduces branching and dependencies in the loop.
initialized = idx + (self.predicate)(&element) as usize;
// SAFETY: Loop conditions ensure the index is in bounds.
unsafe { array.get_unchecked_mut(idx) }.write(element);
if initialized < N { ControlFlow::Continue(()) } else { ControlFlow::Break(()) }
});
match result {
ControlFlow::Break(()) => {
// SAFETY: The loop above is only explicitly broken when the array has been fully initialized
Ok(unsafe { MaybeUninit::array_assume_init(array) })
}
ControlFlow::Continue(()) => {
// SAFETY: The range is in bounds since the loop breaks when reaching N elements.
Err(unsafe { array::IntoIter::new_unchecked(array, 0..initialized) })
}
}
}
}
#[stable(feature = "core_impl_debug", since = "1.9.0")]
impl<I: fmt::Debug, P> fmt::Debug for Filter<I, P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@ -64,52 +100,16 @@ fn next(&mut self) -> Option<I::Item> {
fn next_chunk<const N: usize>(
&mut self,
) -> Result<[Self::Item; N], array::IntoIter<Self::Item, N>> {
let mut array: [MaybeUninit<Self::Item>; N] = [const { MaybeUninit::uninit() }; N];
struct Guard<'a, T> {
array: &'a mut [MaybeUninit<T>],
initialized: usize,
}
impl<T> Drop for Guard<'_, T> {
#[inline]
fn drop(&mut self) {
if const { crate::mem::needs_drop::<T>() } {
// SAFETY: self.initialized is always <= N, which also is the length of the array.
unsafe {
core::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(
self.array.get_unchecked_mut(..self.initialized),
));
}
}
// avoid codegen for the dead branch
let fun = const {
if crate::mem::needs_drop::<I::Item>() {
array::iter_next_chunk::<I::Item, N>
} else {
Self::next_chunk_dropless::<N>
}
}
};
let mut guard = Guard { array: &mut array, initialized: 0 };
let result = self.iter.try_for_each(|element| {
let idx = guard.initialized;
guard.initialized = idx + (self.predicate)(&element) as usize;
// SAFETY: Loop conditions ensure the index is in bounds.
unsafe { guard.array.get_unchecked_mut(idx) }.write(element);
if guard.initialized < N { ControlFlow::Continue(()) } else { ControlFlow::Break(()) }
});
let guard = ManuallyDrop::new(guard);
match result {
ControlFlow::Break(()) => {
// SAFETY: The loop above is only explicitly broken when the array has been fully initialized
Ok(unsafe { MaybeUninit::array_assume_init(array) })
}
ControlFlow::Continue(()) => {
let initialized = guard.initialized;
// SAFETY: The range is in bounds since the loop breaks when reaching N elements.
Err(unsafe { array::IntoIter::new_unchecked(array, 0..initialized) })
}
}
fun(self)
}
#[inline]

View File

@ -1,4 +1,5 @@
use core::iter::*;
use std::rc::Rc;
#[test]
fn test_iterator_filter_count() {
@ -50,3 +51,15 @@ fn test_double_ended_filter() {
assert_eq!(it.next().unwrap(), &2);
assert_eq!(it.next_back(), None);
}
#[test]
fn test_next_chunk_does_not_leak() {
let drop_witness: [_; 5] = std::array::from_fn(|_| Rc::new(()));
let v = (0..5).map(|i| drop_witness[i].clone()).collect::<Vec<_>>();
let _ = v.into_iter().filter(|_| false).next_chunk::<1>();
for ref w in drop_witness {
assert_eq!(Rc::strong_count(w), 1);
}
}