From 6837b812e69a0c024021e3a828bcdc5c7d8f1b4d Mon Sep 17 00:00:00 2001
From: Wim Looman
Date: Sun, 28 Jan 2024 12:19:41 +0100
Subject: [PATCH] Fix some `Arc` allocator leaks

This doesn't matter for the stable `Global` allocator as it is a ZST
singleton, but other allocators may rely on all instances being dropped.
---
 library/alloc/src/sync.rs       | 44 ++++++++++------------
 library/alloc/src/sync/tests.rs | 65 ++++++++++++++++++++++++++++++---
 2 files changed, 80 insertions(+), 29 deletions(-)

diff --git a/library/alloc/src/sync.rs b/library/alloc/src/sync.rs
index 48c8d9d113b..29b00c97b47 100644
--- a/library/alloc/src/sync.rs
+++ b/library/alloc/src/sync.rs
@@ -279,6 +279,12 @@ unsafe fn from_ptr(ptr: *mut ArcInner<T>) -> Self {
 }
 
 impl<T: ?Sized, A: Allocator> Arc<T, A> {
+    #[inline]
+    fn internal_into_inner_with_allocator(self) -> (NonNull<ArcInner<T>>, A) {
+        let this = mem::ManuallyDrop::new(self);
+        (this.ptr, unsafe { ptr::read(&this.alloc) })
+    }
+
     #[inline]
     unsafe fn from_inner_in(ptr: NonNull<ArcInner<T>>, alloc: A) -> Self {
         Self { ptr, phantom: PhantomData, alloc }
@@ -1271,12 +1277,9 @@ impl<T, A: Allocator> Arc<mem::MaybeUninit<T>, A> {
     #[unstable(feature = "new_uninit", issue = "63291")]
     #[must_use = "`self` will be dropped if the result is not used"]
     #[inline]
-    pub unsafe fn assume_init(self) -> Arc<T, A>
-    where
-        A: Clone,
-    {
-        let md_self = mem::ManuallyDrop::new(self);
-        unsafe { Arc::from_inner_in(md_self.ptr.cast(), md_self.alloc.clone()) }
+    pub unsafe fn assume_init(self) -> Arc<T, A> {
+        let (ptr, alloc) = self.internal_into_inner_with_allocator();
+        unsafe { Arc::from_inner_in(ptr.cast(), alloc) }
     }
 }
 
@@ -1316,12 +1319,9 @@ impl<T, A: Allocator> Arc<[mem::MaybeUninit<T>], A> {
     #[unstable(feature = "new_uninit", issue = "63291")]
     #[must_use = "`self` will be dropped if the result is not used"]
     #[inline]
-    pub unsafe fn assume_init(self) -> Arc<[T], A>
-    where
-        A: Clone,
-    {
-        let md_self = mem::ManuallyDrop::new(self);
-        unsafe { Arc::from_ptr_in(md_self.ptr.as_ptr() as _, md_self.alloc.clone()) }
+    pub unsafe fn assume_init(self) -> Arc<[T], A> {
+        let (ptr, alloc) = self.internal_into_inner_with_allocator();
+        unsafe { Arc::from_ptr_in(ptr.as_ptr() as _, alloc) }
     }
 }
 
@@ -2409,7 +2409,7 @@ fn drop(&mut self) {
     }
 }
 
-impl<A: Allocator + Clone> Arc<dyn Any + Send + Sync, A> {
+impl<A: Allocator> Arc<dyn Any + Send + Sync, A> {
     /// Attempt to downcast the `Arc` to a concrete type.
     ///
     /// # Examples
@@ -2436,10 +2436,8 @@ pub fn downcast<T>(self) -> Result<Arc<T, A>, Self>
     {
         if (*self).is::<T>() {
             unsafe {
-                let ptr = self.ptr.cast::<ArcInner<T>>();
-                let alloc = self.alloc.clone();
-                mem::forget(self);
-                Ok(Arc::from_inner_in(ptr, alloc))
+                let (ptr, alloc) = self.internal_into_inner_with_allocator();
+                Ok(Arc::from_inner_in(ptr.cast(), alloc))
             }
         } else {
             Err(self)
@@ -2479,10 +2477,8 @@ pub unsafe fn downcast_unchecked<T>(self) -> Arc<T, A>
         T: Any + Send + Sync,
     {
         unsafe {
-            let ptr = self.ptr.cast::<ArcInner<T>>();
-            let alloc = self.alloc.clone();
-            mem::forget(self);
-            Arc::from_inner_in(ptr, alloc)
+            let (ptr, alloc) = self.internal_into_inner_with_allocator();
+            Arc::from_inner_in(ptr.cast(), alloc)
         }
     }
 }
@@ -3438,13 +3434,13 @@ fn from(rc: Arc) -> Self {
 }
 
 #[stable(feature = "boxed_slice_try_from", since = "1.43.0")]
-impl<T, A: Allocator + Clone, const N: usize> TryFrom<Arc<[T], A>> for Arc<[T; N], A> {
+impl<T, A: Allocator, const N: usize> TryFrom<Arc<[T], A>> for Arc<[T; N], A> {
     type Error = Arc<[T], A>;
 
     fn try_from(boxed_slice: Arc<[T], A>) -> Result<Self, Self::Error> {
         if boxed_slice.len() == N {
-            let alloc = boxed_slice.alloc.clone();
-            Ok(unsafe { Arc::from_raw_in(Arc::into_raw(boxed_slice) as *mut [T; N], alloc) })
+            let (ptr, alloc) = boxed_slice.internal_into_inner_with_allocator();
+            Ok(unsafe { Arc::from_inner_in(ptr.cast(), alloc) })
         } else {
             Err(boxed_slice)
         }
diff --git a/library/alloc/src/sync/tests.rs b/library/alloc/src/sync/tests.rs
index d37e45569cf..49eae718c16 100644
--- a/library/alloc/src/sync/tests.rs
+++ b/library/alloc/src/sync/tests.rs
@@ -1,13 +1,15 @@
 use super::*;
 
 use std::clone::Clone;
+use std::mem::MaybeUninit;
 use std::option::Option::None;
+use std::sync::atomic::AtomicUsize;
 use std::sync::atomic::Ordering::SeqCst;
 use std::sync::mpsc::channel;
 use std::sync::Mutex;
 use std::thread;
 
-struct Canary(*mut atomic::AtomicUsize);
+struct Canary(*mut AtomicUsize);
 
 impl Drop for Canary {
     fn drop(&mut self) {
@@ -21,6 +23,37 @@ fn drop(&mut self) {
     }
 }
 
+struct AllocCanary<'a>(&'a AtomicUsize);
+
+impl<'a> AllocCanary<'a> {
+    fn new(counter: &'a AtomicUsize) -> Self {
+        counter.fetch_add(1, SeqCst);
+        Self(counter)
+    }
+}
+
+unsafe impl Allocator for AllocCanary<'_> {
+    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+        std::alloc::Global.allocate(layout)
+    }
+
+    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
+        unsafe { std::alloc::Global.deallocate(ptr, layout) }
+    }
+}
+
+impl Clone for AllocCanary<'_> {
+    fn clone(&self) -> Self {
+        Self::new(self.0)
+    }
+}
+
+impl Drop for AllocCanary<'_> {
+    fn drop(&mut self) {
+        self.0.fetch_sub(1, SeqCst);
+    }
+}
+
 #[test]
 #[cfg_attr(target_os = "emscripten", ignore)]
 fn manually_share_arc() {
@@ -295,16 +328,16 @@ struct Cycle {
 
 #[test]
 fn drop_arc() {
-    let mut canary = atomic::AtomicUsize::new(0);
-    let x = Arc::new(Canary(&mut canary as *mut atomic::AtomicUsize));
+    let mut canary = AtomicUsize::new(0);
+    let x = Arc::new(Canary(&mut canary as *mut AtomicUsize));
     drop(x);
     assert!(canary.load(Acquire) == 1);
 }
 
 #[test]
 fn drop_arc_weak() {
-    let mut canary = atomic::AtomicUsize::new(0);
-    let arc = Arc::new(Canary(&mut canary as *mut atomic::AtomicUsize));
+    let mut canary = AtomicUsize::new(0);
+    let arc = Arc::new(Canary(&mut canary as *mut AtomicUsize));
     let arc_weak = Arc::downgrade(&arc);
     assert!(canary.load(Acquire) == 0);
     drop(arc);
@@ -660,3 +693,25 @@ fn arc_drop_dereferenceable_race() {
         thread.join().unwrap();
     }
 }
+
+#[test]
+fn arc_doesnt_leak_allocator() {
+    let counter = AtomicUsize::new(0);
+
+    {
+        let arc: Arc<dyn Any + Send + Sync, _> = Arc::new_in(5usize, AllocCanary::new(&counter));
+        drop(arc.downcast::<usize>().unwrap());
+
+        let arc: Arc<dyn Any + Send + Sync, _> = Arc::new_in(5usize, AllocCanary::new(&counter));
+        drop(unsafe { arc.downcast_unchecked::<usize>() });
+
+        let arc = Arc::new_in(MaybeUninit::<usize>::new(5usize), AllocCanary::new(&counter));
+        drop(unsafe { arc.assume_init() });
+
+        let arc: Arc<[MaybeUninit<usize>], _> =
+            Arc::new_zeroed_slice_in(5, AllocCanary::new(&counter));
+        drop(unsafe { arc.assume_init() });
+    }
+
+    assert_eq!(counter.load(SeqCst), 0);
+}
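
A note on the pattern the new helper relies on: the old code cloned the allocator and then called `mem::forget(self)`, so the allocator instance still stored inside the forgotten `Arc` was never dropped; the new `internal_into_inner_with_allocator` instead wraps the `Arc` in `ManuallyDrop` (suppressing its destructor) and moves the allocator out with `ptr::read`, so no instance is cloned or leaked. A minimal standalone sketch of that pattern, using a hypothetical `Holder` type in place of `Arc` (illustrative only, not part of the patch):

use core::mem::{self, ManuallyDrop};
use core::ptr;

struct Holder<A> {
    value: u32,
    alloc: A,
}

// `Holder` has a destructor, so its fields cannot simply be moved out.
impl<A> Drop for Holder<A> {
    fn drop(&mut self) {
        // imagine the stored value being deallocated here
    }
}

impl<A> Holder<A> {
    // Old pattern: clone the allocator and forget `self`. The `A` instance
    // still inside the forgotten `Holder` is never dropped. Harmless for a
    // ZST singleton like `Global`, a leak for allocators with per-instance state.
    fn into_parts_leaky(self) -> (u32, A)
    where
        A: Clone,
    {
        let alloc = self.alloc.clone();
        let value = self.value;
        mem::forget(self);
        (value, alloc)
    }

    // New pattern: suppress `Drop` with `ManuallyDrop`, then move the fields
    // out by value. No clone, and no allocator instance is left behind.
    fn into_parts(self) -> (u32, A) {
        let this = ManuallyDrop::new(self);
        (this.value, unsafe { ptr::read(&this.alloc) })
    }
}

In the patch, `internal_into_inner_with_allocator` is the second form: it returns the `ArcInner` pointer together with the moved-out allocator, and every conversion that previously cloned-and-forgot now goes through it.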