Skip to main content

compio_driver/
fd.rs

1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6    future::{Future, poll_fn},
7    mem::ManuallyDrop,
8    ops::Deref,
9    panic::RefUnwindSafe,
10    ptr,
11    sync::atomic::Ordering,
12    task::Poll,
13};
14
15use crate::{AsFd, AsRawFd, BorrowedFd, RawFd};
16
17cfg_if::cfg_if! {
18    if #[cfg(feature = "sync")] {
19        use synchrony::sync;
20    } else {
21        use synchrony::unsync as sync;
22    }
23}
24
25use sync::{atomic::AtomicBool, shared::Shared, waker_slot::WakerSlot};
26
/// Reference-counted state shared by every [`SharedFd`] clone and by a
/// pending [`SharedFd::take`] future.
#[derive(Debug)]
struct Inner<T> {
    /// The wrapped value that owns the fd.
    fd: T,
    /// Whether there is a future waiting: set to `true` by the first
    /// [`SharedFd::take`] future that starts waiting for the other handles to
    /// be dropped.
    waits: AtomicBool,
    /// Waker of the waiting `take` future; it is registered by that future and
    /// woken from `SharedFd`'s `Drop` impl.
    waker: WakerSlot,
}
34
// Declares `Inner<T>` unwind-safe behind a shared reference, independent of
// whether its synchronization fields (`waits`, `waker`) would be on their own.
// NOTE(review): the impl has no `T: RefUnwindSafe` bound, so it also covers `T`s
// that are not `RefUnwindSafe` themselves — confirm this is intended.
impl<T> RefUnwindSafe for Inner<T> {}
36
/// A shared fd. It is passed to the operations to make sure the fd won't be
/// closed before the operations complete.
///
/// Every clone refers to the same reference-counted state. The owned value can
/// be recovered with [`SharedFd::try_unwrap`] (only if no other reference
/// exists) or [`SharedFd::take`] (waits for the other references to go away).
#[derive(Debug)]
pub struct SharedFd<T>(Shared<Inner<T>>);
41
42impl<T: AsFd> SharedFd<T> {
43    /// Create the shared fd from an owned fd.
44    pub fn new(fd: T) -> Self {
45        unsafe { Self::new_unchecked(fd) }
46    }
47}
48
49impl<T> SharedFd<T> {
50    /// Create the shared fd.
51    ///
52    /// # Safety
53    /// * T should own the fd.
54    pub unsafe fn new_unchecked(fd: T) -> Self {
55        Self(Shared::new(Inner {
56            fd,
57            waits: AtomicBool::new(false),
58            waker: WakerSlot::new(),
59        }))
60    }
61
62    fn into_inner(self) -> Shared<Inner<T>> {
63        let this = ManuallyDrop::new(self);
64        // SAFETY: `this` is not dropped here.
65        unsafe { ptr::read(&this.0) }
66    }
67
68    /// Try to take the inner owned fd.
69    pub fn try_unwrap(self) -> Result<T, Self> {
70        let inner = self.into_inner();
71        Shared::try_unwrap(inner).map(|t| t.fd).map_err(|i| Self(i))
72    }
73
74    /// Wait and take the inner owned fd.
75    pub fn take(self) -> impl Future<Output = Option<T>> {
76        let inner = self.into_inner();
77
78        async move {
79            if !inner.waits.swap(true, Ordering::AcqRel) {
80                let mut inner = Some(inner);
81                poll_fn(move |cx| {
82                    let i = inner.take().unwrap();
83                    let this = match Shared::try_unwrap(i) {
84                        Ok(fd) => return Poll::Ready(Some(fd.fd)),
85                        Err(this) => this,
86                    };
87
88                    this.waker.register(cx.waker());
89
90                    match Shared::try_unwrap(this) {
91                        Ok(fd) => Poll::Ready(Some(fd.fd)),
92                        Err(tt) => {
93                            inner = Some(tt);
94                            Poll::Pending
95                        }
96                    }
97                })
98                .await
99            } else {
100                None
101            }
102        }
103    }
104}
105
106impl<T> Drop for SharedFd<T> {
107    fn drop(&mut self) {
108        // It's OK to wake multiple times.
109        if Shared::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
110            self.0.waker.wake()
111        }
112    }
113}
114
115impl<T: AsFd> AsFd for SharedFd<T> {
116    fn as_fd(&self) -> BorrowedFd<'_> {
117        self.0.fd.as_fd()
118    }
119}
120
121impl<T: AsFd> AsRawFd for SharedFd<T> {
122    fn as_raw_fd(&self) -> RawFd {
123        self.as_fd().as_raw_fd()
124    }
125}
126
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Builds the owner `T` from `handle`, then wraps it in a new shared fd.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: the caller upholds `FromRawHandle`'s contract for `handle`.
        let owner = unsafe { T::from_raw_handle(handle) };
        // SAFETY: per `FromRawHandle`, `owner` has taken ownership of the handle,
        // which is what `new_unchecked` requires.
        unsafe { Self::new_unchecked(owner) }
    }
}
133
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Builds the owner `T` from `sock`, then wraps it in a new shared fd.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: the caller upholds `FromRawSocket`'s contract for `sock`.
        let owner = unsafe { T::from_raw_socket(sock) };
        // SAFETY: per `FromRawSocket`, `owner` has taken ownership of the socket,
        // which is what `new_unchecked` requires.
        unsafe { Self::new_unchecked(owner) }
    }
}
140
141#[cfg(unix)]
142impl<T: FromRawFd> FromRawFd for SharedFd<T> {
143    unsafe fn from_raw_fd(fd: RawFd) -> Self {
144        unsafe { Self::new_unchecked(T::from_raw_fd(fd)) }
145    }
146}
147
148impl<T> Clone for SharedFd<T> {
149    fn clone(&self) -> Self {
150        Self(self.0.clone())
151    }
152}
153
154impl<T> Deref for SharedFd<T> {
155    type Target = T;
156
157    fn deref(&self) -> &Self::Target {
158        &self.0.fd
159    }
160}
161
/// Get a clone of [`SharedFd`].
///
/// [`SharedFd`] implements this by cloning itself, so the returned handle refers
/// to the same shared state as the original.
pub trait ToSharedFd<T> {
    /// Return a cloned [`SharedFd`].
    fn to_shared_fd(&self) -> SharedFd<T>;
}
167
168impl<T> ToSharedFd<T> for SharedFd<T> {
169    fn to_shared_fd(&self) -> SharedFd<T> {
170        self.clone()
171    }
172}