summaryrefslogtreecommitdiff
path: root/rust/pin-init/examples/mutex.rs
diff options
context:
space:
mode:
Diffstat (limited to 'rust/pin-init/examples/mutex.rs')
-rw-r--r--rust/pin-init/examples/mutex.rs209
1 files changed, 209 insertions, 0 deletions
diff --git a/rust/pin-init/examples/mutex.rs b/rust/pin-init/examples/mutex.rs
new file mode 100644
index 000000000000..073bb79341d1
--- /dev/null
+++ b/rust/pin-init/examples/mutex.rs
@@ -0,0 +1,209 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+#![allow(clippy::undocumented_unsafe_blocks)]
+#![cfg_attr(feature = "alloc", feature(allocator_api))]
+#![allow(clippy::missing_safety_doc)]
+
+use core::{
+ cell::{Cell, UnsafeCell},
+ marker::PhantomPinned,
+ ops::{Deref, DerefMut},
+ pin::Pin,
+ sync::atomic::{AtomicBool, Ordering},
+};
+use std::{
+ sync::Arc,
+ thread::{self, park, sleep, Builder, Thread},
+ time::Duration,
+};
+
+use pin_init::*;
+#[expect(unused_attributes)]
+#[path = "./linked_list.rs"]
+pub mod linked_list;
+use linked_list::*;
+
+pub struct SpinLock {
+ inner: AtomicBool,
+}
+
+impl SpinLock {
+ #[inline]
+ pub fn acquire(&self) -> SpinLockGuard<'_> {
+ while self
+ .inner
+ .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
+ .is_err()
+ {
+ while self.inner.load(Ordering::Relaxed) {
+ thread::yield_now();
+ }
+ }
+ SpinLockGuard(self)
+ }
+
+ #[inline]
+ #[allow(clippy::new_without_default)]
+ pub const fn new() -> Self {
+ Self {
+ inner: AtomicBool::new(false),
+ }
+ }
+}
+
+pub struct SpinLockGuard<'a>(&'a SpinLock);
+
+impl Drop for SpinLockGuard<'_> {
+ #[inline]
+ fn drop(&mut self) {
+ self.0.inner.store(false, Ordering::Release);
+ }
+}
+
+#[pin_data]
+pub struct CMutex<T> {
+ #[pin]
+ wait_list: ListHead,
+ spin_lock: SpinLock,
+ locked: Cell<bool>,
+ #[pin]
+ data: UnsafeCell<T>,
+}
+
+impl<T> CMutex<T> {
+ #[inline]
+ pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
+ pin_init!(CMutex {
+ wait_list <- ListHead::new(),
+ spin_lock: SpinLock::new(),
+ locked: Cell::new(false),
+ data <- unsafe {
+ pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
+ val.__pinned_init(slot.cast::<T>())
+ })
+ },
+ })
+ }
+
+ #[inline]
+ pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
+ let mut sguard = self.spin_lock.acquire();
+ if self.locked.get() {
+ stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
+ // println!("wait list length: {}", self.wait_list.size());
+ while self.locked.get() {
+ drop(sguard);
+ park();
+ sguard = self.spin_lock.acquire();
+ }
+ // This does have an effect, as the ListHead inside wait_entry implements Drop!
+ #[expect(clippy::drop_non_drop)]
+ drop(wait_entry);
+ }
+ self.locked.set(true);
+ unsafe {
+ Pin::new_unchecked(CMutexGuard {
+ mtx: self,
+ _pin: PhantomPinned,
+ })
+ }
+ }
+
+ #[allow(dead_code)]
+ pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
+ // SAFETY: we have an exclusive reference and thus nobody has access to data.
+ unsafe { &mut *self.data.get() }
+ }
+}
+
+unsafe impl<T: Send> Send for CMutex<T> {}
+unsafe impl<T: Send> Sync for CMutex<T> {}
+
+pub struct CMutexGuard<'a, T> {
+ mtx: &'a CMutex<T>,
+ _pin: PhantomPinned,
+}
+
+impl<T> Drop for CMutexGuard<'_, T> {
+ #[inline]
+ fn drop(&mut self) {
+ let sguard = self.mtx.spin_lock.acquire();
+ self.mtx.locked.set(false);
+ if let Some(list_field) = self.mtx.wait_list.next() {
+ let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
+ unsafe { (*wait_entry).thread.unpark() };
+ }
+ drop(sguard);
+ }
+}
+
+impl<T> Deref for CMutexGuard<'_, T> {
+ type Target = T;
+
+ #[inline]
+ fn deref(&self) -> &Self::Target {
+ unsafe { &*self.mtx.data.get() }
+ }
+}
+
+impl<T> DerefMut for CMutexGuard<'_, T> {
+ #[inline]
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ unsafe { &mut *self.mtx.data.get() }
+ }
+}
+
+#[pin_data]
+#[repr(C)]
+struct WaitEntry {
+ #[pin]
+ wait_list: ListHead,
+ thread: Thread,
+}
+
+impl WaitEntry {
+ #[inline]
+ fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
+ pin_init!(Self {
+ thread: thread::current(),
+ wait_list <- ListHead::insert_prev(list),
+ })
+ }
+}
+
+#[cfg(not(any(feature = "std", feature = "alloc")))]
+fn main() {}
+
+#[allow(dead_code)]
+#[cfg_attr(test, test)]
+#[cfg(any(feature = "std", feature = "alloc"))]
+fn main() {
+ let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
+ let mut handles = vec![];
+ let thread_count = 20;
+ let workload = if cfg!(miri) { 100 } else { 1_000 };
+ for i in 0..thread_count {
+ let mtx = mtx.clone();
+ handles.push(
+ Builder::new()
+ .name(format!("worker #{i}"))
+ .spawn(move || {
+ for _ in 0..workload {
+ *mtx.lock() += 1;
+ }
+ println!("{i} halfway");
+ sleep(Duration::from_millis((i as u64) * 10));
+ for _ in 0..workload {
+ *mtx.lock() += 1;
+ }
+ println!("{i} finished");
+ })
+ .expect("should not fail"),
+ );
+ }
+ for h in handles {
+ h.join().expect("thread panicked");
+ }
+ println!("{:?}", &*mtx.lock());
+ assert_eq!(*mtx.lock(), workload * thread_count * 2);
+}