Skip to main content

compio_driver/
fd.rs

1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6    future::{Future, poll_fn},
7    mem::ManuallyDrop,
8    ops::Deref,
9    panic::RefUnwindSafe,
10    ptr,
11    sync::atomic::Ordering,
12    task::Poll,
13};
14
15use crate::{AsFd, AsRawFd, BorrowedFd, RawFd};
16
17cfg_select! {
18    feature = "sync" => {
19        use synchrony::sync;
20    }
21    _ => {
22        use synchrony::unsync as sync;
23    }
24}
25
26use sync::{atomic::AtomicBool, shared::Shared, waker_slot::WakerSlot};
27
/// State shared by every clone of a [`SharedFd`].
#[derive(Debug)]
struct Inner<T> {
    /// The owned fd value.
    fd: T,
    // Whether a `SharedFd::take` future is waiting. The first polled `take` swaps this to
    // `true`; no visible code resets it to `false` afterwards.
    waits: AtomicBool,
    // Waker registered by the waiting `take` future; woken from `Drop for SharedFd`.
    waker: WakerSlot,
}

// Marks `Inner` unwind-safe for every `T`, not only `T: RefUnwindSafe`.
// NOTE(review): presumably needed because the `synchrony` atomic/waker types are not
// `RefUnwindSafe` by themselves — confirm, and decide whether a `T: RefUnwindSafe` bound
// would be more accurate.
impl<T> RefUnwindSafe for Inner<T> {}
37
/// A shared fd. It is passed to the operations to make sure the fd won't be
/// closed before the operations complete.
///
/// Clones are extra handles to one shared inner value; the fd itself is never
/// duplicated. The owned value can be recovered with [`SharedFd::try_unwrap`] or
/// [`SharedFd::take`].
#[derive(Debug)]
pub struct SharedFd<T>(Shared<Inner<T>>);
42
43impl<T: AsFd> SharedFd<T> {
44    /// Create the shared fd from an owned fd.
45    pub fn new(fd: T) -> Self {
46        unsafe { Self::new_unchecked(fd) }
47    }
48}
49
50impl<T> SharedFd<T> {
51    /// Create the shared fd.
52    ///
53    /// # Safety
54    /// * T should own the fd.
55    pub unsafe fn new_unchecked(fd: T) -> Self {
56        Self(Shared::new(Inner {
57            fd,
58            waits: AtomicBool::new(false),
59            waker: WakerSlot::new(),
60        }))
61    }
62
63    fn into_inner(self) -> Shared<Inner<T>> {
64        let this = ManuallyDrop::new(self);
65        // SAFETY: `this` is not dropped here.
66        unsafe { ptr::read(&this.0) }
67    }
68
69    /// Try to take the inner owned fd.
70    pub fn try_unwrap(self) -> Result<T, Self> {
71        let inner = self.into_inner();
72        Shared::try_unwrap(inner).map(|t| t.fd).map_err(|i| Self(i))
73    }
74
75    /// Wait and take the inner owned fd.
76    pub fn take(self) -> impl Future<Output = Option<T>> {
77        let inner = self.into_inner();
78
79        async move {
80            if !inner.waits.swap(true, Ordering::AcqRel) {
81                let mut inner = Some(inner);
82                poll_fn(move |cx| {
83                    let i = inner.take().unwrap();
84                    let this = match Shared::try_unwrap(i) {
85                        Ok(fd) => return Poll::Ready(Some(fd.fd)),
86                        Err(this) => this,
87                    };
88
89                    this.waker.register(cx.waker());
90
91                    match Shared::try_unwrap(this) {
92                        Ok(fd) => Poll::Ready(Some(fd.fd)),
93                        Err(tt) => {
94                            inner = Some(tt);
95                            Poll::Pending
96                        }
97                    }
98                })
99                .await
100            } else {
101                None
102            }
103        }
104    }
105}
106
107impl<T> Drop for SharedFd<T> {
108    fn drop(&mut self) {
109        // It's OK to wake multiple times.
110        if Shared::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
111            self.0.waker.wake()
112        }
113    }
114}
115
116impl<T: AsFd> AsFd for SharedFd<T> {
117    fn as_fd(&self) -> BorrowedFd<'_> {
118        self.0.fd.as_fd()
119    }
120}
121
122impl<T: AsFd> AsRawFd for SharedFd<T> {
123    fn as_raw_fd(&self) -> RawFd {
124        self.as_fd().as_raw_fd()
125    }
126}
127
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Build `T` from the raw handle and wrap it in a new `SharedFd`.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: the caller's `FromRawHandle` contract (an open, owned handle) is
        // forwarded unchanged to `T`.
        let owned = unsafe { T::from_raw_handle(handle) };
        // SAFETY: `owned` took ownership of the handle above, as `new_unchecked` requires.
        unsafe { Self::new_unchecked(owned) }
    }
}
134
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Build `T` from the raw socket and wrap it in a new `SharedFd`.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: the caller's `FromRawSocket` contract (an open, owned socket) is
        // forwarded unchanged to `T`.
        let owned = unsafe { T::from_raw_socket(sock) };
        // SAFETY: `owned` took ownership of the socket above, as `new_unchecked` requires.
        unsafe { Self::new_unchecked(owned) }
    }
}
141
142#[cfg(unix)]
143impl<T: FromRawFd> FromRawFd for SharedFd<T> {
144    unsafe fn from_raw_fd(fd: RawFd) -> Self {
145        unsafe { Self::new_unchecked(T::from_raw_fd(fd)) }
146    }
147}
148
149impl<T> Clone for SharedFd<T> {
150    fn clone(&self) -> Self {
151        Self(self.0.clone())
152    }
153}
154
155impl<T> Deref for SharedFd<T> {
156    type Target = T;
157
158    fn deref(&self) -> &Self::Target {
159        &self.0.fd
160    }
161}
162
/// Get a clone of [`SharedFd`].
///
/// Lets callers obtain a new [`SharedFd`] handle from `&self` without consuming it.
pub trait ToSharedFd<T> {
    /// Return a cloned [`SharedFd`].
    fn to_shared_fd(&self) -> SharedFd<T>;
}
168
169impl<T> ToSharedFd<T> for SharedFd<T> {
170    fn to_shared_fd(&self) -> SharedFd<T> {
171        self.clone()
172    }
173}