// compio_io/compat/async_stream.rs

1use std::{
2    fmt::Debug,
3    io::{self, BufRead},
4    mem::MaybeUninit,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use crate::{PinBoxFuture, compat::SyncStream};
10
11/// A stream wrapper for [`futures_util::io`] traits.
12pub struct AsyncStream<S> {
13    // The futures keep the reference to the inner stream, so we need to pin
14    // the inner stream to make sure the reference is valid.
15    inner: Pin<Box<SyncStream<S>>>,
16    read_future: Option<PinBoxFuture<io::Result<usize>>>,
17    write_future: Option<PinBoxFuture<io::Result<usize>>>,
18    shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
19}
20
21impl<S> AsyncStream<S> {
22    /// Create [`AsyncStream`] with the stream and default buffer size.
23    pub fn new(stream: S) -> Self {
24        Self::new_impl(SyncStream::new(stream))
25    }
26
27    /// Create [`AsyncStream`] with the stream and buffer size.
28    pub fn with_capacity(cap: usize, stream: S) -> Self {
29        Self::new_impl(SyncStream::with_capacity(cap, stream))
30    }
31
32    fn new_impl(inner: SyncStream<S>) -> Self {
33        Self {
34            inner: Box::pin(inner),
35            read_future: None,
36            write_future: None,
37            shutdown_future: None,
38        }
39    }
40
41    /// Get the reference of the inner stream.
42    pub fn get_ref(&self) -> &S {
43        self.inner.get_ref()
44    }
45}
46
/// Drive the boxed future held in the `Option` slot `$f`, creating it from
/// `$e` first when the slot is empty.
///
/// Evaluates to the future's output once it is ready. While it is still
/// pending, the future is stored back into `$f` and the enclosing function
/// returns `Poll::Pending`.
macro_rules! poll_future {
    ($f:expr, $cx:expr, $e:expr) => {{
        // Resume the stored future, or start a fresh one.
        let mut fut = if let Some(existing) = $f.take() {
            existing
        } else {
            Box::pin($e)
        };
        if let Poll::Ready(output) = fut.as_mut().poll($cx) {
            output
        } else {
            // Keep the future so the next poll resumes where it left off.
            $f.replace(fut);
            return Poll::Pending;
        }
    }};
}
63
/// Drive the in-flight fill/flush future stored in `$f` (if any), then try the
/// synchronous buffered operation `$io`.
///
/// - `$f`: `Option` slot holding the in-flight boxed future.
/// - `$cx`: the task context.
/// - `$e`: expression that creates a new future refilling/draining the buffer.
/// - `$io`: the synchronous buffered operation, evaluating to `io::Result<_>`.
///
/// If the stored future is still pending, the enclosing function returns
/// `Poll::Pending`. If it failed, its error is returned. Otherwise `$io` runs:
/// when it reports `WouldBlock`, a new future from `$e` is stored and the task
/// is woken so the next poll drives it.
macro_rules! poll_future_would_block {
    ($f:expr, $cx:expr, $e:expr, $io:expr) => {{
        if let Some(mut f) = $f.take() {
            match f.as_mut().poll($cx) {
                Poll::Pending => {
                    $f.replace(f);
                    return Poll::Pending;
                }
                // Surface the failure. If it were dropped, `$io` could report
                // `WouldBlock` again, a fresh future would be created and the
                // error would be retried (possibly forever) without ever
                // reaching the caller.
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                // The byte count is unused; `$io` below is retried regardless.
                Poll::Ready(Ok(_)) => {}
            }
        }

        match $io {
            Ok(len) => Poll::Ready(Ok(len)),
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                $f.replace(Box::pin($e));
                // The new future has not been polled yet, so request an
                // immediate re-poll to start it.
                $cx.waker().wake_by_ref();
                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }};
}
84
85impl<S: crate::AsyncRead + 'static> futures_util::AsyncRead for AsyncStream<S> {
86    fn poll_read(
87        mut self: Pin<&mut Self>,
88        cx: &mut Context<'_>,
89        buf: &mut [u8],
90    ) -> Poll<io::Result<usize>> {
91        // SAFETY:
92        // - The futures won't live longer than the stream.
93        // - The inner stream is pinned.
94        let inner: &'static mut SyncStream<S> =
95            unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
96
97        poll_future_would_block!(
98            self.read_future,
99            cx,
100            inner.fill_read_buf(),
101            io::Read::read(inner, buf)
102        )
103    }
104}
105
106impl<S: crate::AsyncRead + 'static> AsyncStream<S> {
107    /// Attempt to read from the `AsyncRead` into `buf`.
108    ///
109    /// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
110    pub fn poll_read_uninit(
111        mut self: Pin<&mut Self>,
112        cx: &mut Context<'_>,
113        buf: &mut [MaybeUninit<u8>],
114    ) -> Poll<io::Result<usize>> {
115        let inner: &'static mut SyncStream<S> =
116            unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
117        poll_future_would_block!(
118            self.read_future,
119            cx,
120            inner.fill_read_buf(),
121            inner.read_buf_uninit(buf)
122        )
123    }
124}
125
126impl<S: crate::AsyncRead + 'static> futures_util::AsyncBufRead for AsyncStream<S> {
127    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
128        let inner: &'static mut SyncStream<S> =
129            unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
130        poll_future_would_block!(
131            self.read_future,
132            cx,
133            inner.fill_read_buf(),
134            // SAFETY: anyway the slice won't be used after free.
135            io::BufRead::fill_buf(inner).map(|slice| unsafe { &*(slice as *const _) })
136        )
137    }
138
139    fn consume(mut self: Pin<&mut Self>, amt: usize) {
140        unsafe { self.inner.as_mut().get_unchecked_mut().consume(amt) }
141    }
142}
143
144impl<S: crate::AsyncWrite + 'static> futures_util::AsyncWrite for AsyncStream<S> {
145    fn poll_write(
146        mut self: Pin<&mut Self>,
147        cx: &mut Context<'_>,
148        buf: &[u8],
149    ) -> Poll<io::Result<usize>> {
150        if self.shutdown_future.is_some() {
151            debug_assert!(self.write_future.is_none());
152            return Poll::Pending;
153        }
154
155        let inner: &'static mut SyncStream<S> =
156            unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
157        poll_future_would_block!(
158            self.write_future,
159            cx,
160            inner.flush_write_buf(),
161            io::Write::write(inner, buf)
162        )
163    }
164
165    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166        if self.shutdown_future.is_some() {
167            debug_assert!(self.write_future.is_none());
168            return Poll::Pending;
169        }
170
171        let inner: &'static mut SyncStream<S> =
172            unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
173        let res = poll_future!(self.write_future, cx, inner.flush_write_buf());
174        Poll::Ready(res.map(|_| ()))
175    }
176
177    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178        // Avoid shutdown on flush because the inner buffer might be passed to the
179        // driver.
180        if self.write_future.is_some() {
181            debug_assert!(self.shutdown_future.is_none());
182            return Poll::Pending;
183        }
184
185        let inner: &'static mut SyncStream<S> =
186            unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
187        let res = poll_future!(self.shutdown_future, cx, inner.get_mut().shutdown());
188        Poll::Ready(res)
189    }
190}
191
192impl<S: Debug> Debug for AsyncStream<S> {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        f.debug_struct("AsyncStream")
195            .field("inner", &self.inner)
196            .finish_non_exhaustive()
197    }
198}