Skip to main content

compio_io/compat/
async_stream.rs

1use std::{
2    fmt::Debug,
3    io,
4    marker::PhantomPinned,
5    mem::MaybeUninit,
6    pin::Pin,
7    task::{Context, Poll, Waker, ready},
8};
9
10use pin_project_lite::pin_project;
11
12use super::waker_array::WakerArrayRef;
13use crate::{
14    AsyncRead, AsyncWrite, PinBoxFuture,
15    compat::{SyncStream, SyncStreamReadHalf, SyncStreamWriteHalf},
16    util::{DEFAULT_BUF_SIZE, Splittable},
17};
18
pin_project! {
    /// A stream wrapper for [`futures_util::io`] traits.
    ///
    /// The stream is split (via [`Splittable`]) into a read half and a write half.
    /// Reads are served by an [`AsyncReadStream`] and writes by an
    /// [`AsyncWriteStream`], each with its own buffer.
    ///
    /// The wrapper is `!Unpin`; its poll methods require it to be pinned.
    pub struct AsyncStream<S: Splittable> {
        // Adapter serving the `futures_util` read traits from the read half.
        #[pin]
        read_inner: AsyncReadStream<S::ReadHalf>,
        // Adapter serving `futures_util::AsyncWrite` from the write half.
        #[pin]
        write_inner: AsyncWriteStream<S::WriteHalf>,
        // Opts the wrapper out of `Unpin`.
        #[pin]
        _p: PhantomPinned,
    }
}
30
31impl<S: Splittable> AsyncStream<S> {
32    /// Create [`AsyncStream`] with the stream and default buffer size.
33    pub fn new(stream: S) -> Self {
34        Self::new_impl(SyncStream::new(stream))
35    }
36
37    /// Create [`AsyncStream`] with the stream and buffer size.
38    pub fn with_capacity(cap: usize, stream: S) -> Self {
39        Self::new_impl(SyncStream::with_capacity(cap, stream))
40    }
41
42    /// Create [`AsyncStream`] with the stream, buffer capacity and max buffer
43    /// size.
44    pub fn with_limits(cap: usize, max_buffer_size: usize, stream: S) -> Self {
45        Self::new_impl(SyncStream::with_limits(cap, max_buffer_size, stream))
46    }
47
48    fn new_impl(inner: SyncStream<S>) -> Self {
49        let (read_inner, write_inner) = inner.split();
50        Self {
51            read_inner: AsyncReadStream::new_impl(read_inner),
52            write_inner: AsyncWriteStream::new_impl(write_inner),
53            _p: PhantomPinned,
54        }
55    }
56
57    /// Get the reference of the inner stream.
58    pub fn get_ref(&self) -> (&S::ReadHalf, &S::WriteHalf) {
59        (self.read_inner.get_ref(), self.write_inner.get_ref())
60    }
61
62    /// Returns a mutable reference to the underlying stream.
63    pub fn get_mut(&mut self) -> (&mut S::ReadHalf, &mut S::WriteHalf) {
64        (self.read_inner.get_mut(), self.write_inner.get_mut())
65    }
66
67    /// Consumes the `AsyncStream`, returning the underlying stream.
68    pub fn into_inner(self) -> (S::ReadHalf, S::WriteHalf) {
69        (self.read_inner.into_inner(), self.write_inner.into_inner())
70    }
71}
72
73pin_project! {
74    /// A read stream wrapper for [`futures_util::io`].
75    ///
76    /// It doesn't support write and shutdown operations, making looser
77    /// requirements on the inner stream.
78    pub struct AsyncReadStream<S> {
79        #[pin]
80        inner: SyncStreamReadHalf<S>,
81        read_future: Option<PinBoxFuture<io::Result<usize>>>,
82        read_waker: Option<Waker>,
83        read_uninit_waker: Option<Waker>,
84        read_buf_waker: Option<Waker>,
85        #[pin]
86        _p: PhantomPinned,
87    }
88}
89
90impl<S> AsyncReadStream<S> {
91    /// Create [`AsyncReadStream`] with the stream and default buffer size.
92    pub fn new(stream: S) -> Self {
93        Self::with_capacity(DEFAULT_BUF_SIZE, stream)
94    }
95
96    /// Create [`AsyncReadStream`] with the stream and buffer size.
97    pub fn with_capacity(cap: usize, stream: S) -> Self {
98        Self::new_impl(SyncStreamReadHalf::with_limits(
99            cap,
100            super::DEFAULT_MAX_BUFFER,
101            stream,
102        ))
103    }
104
105    fn new_impl(inner: SyncStreamReadHalf<S>) -> Self {
106        Self {
107            inner,
108            read_future: None,
109            read_waker: None,
110            read_uninit_waker: None,
111            read_buf_waker: None,
112            _p: PhantomPinned,
113        }
114    }
115
116    /// Get the reference of the inner stream.
117    pub fn get_ref(&self) -> &S {
118        self.inner.get_ref()
119    }
120
121    /// Returns a mutable reference to the underlying stream.
122    pub fn get_mut(&mut self) -> &mut S {
123        self.inner.get_mut()
124    }
125
126    /// Consumes the `SyncStream`, returning the underlying stream.
127    pub fn into_inner(self) -> S {
128        self.inner.into_inner()
129    }
130}
131
132pin_project! {
133    /// A write stream wrapper for [`futures_util::io`].
134    ///
135    /// It doesn't support read operations, making looser requirements on the inner stream.
136    pub struct AsyncWriteStream<S> {
137        #[pin]
138        inner: SyncStreamWriteHalf<S>,
139        write_future: Option<PinBoxFuture<io::Result<usize>>>,
140        shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
141        write_waker: Option<Waker>,
142        flush_waker: Option<Waker>,
143        close_waker: Option<Waker>,
144        closed: bool,
145        #[pin]
146        _p: PhantomPinned,
147    }
148}
149
150impl<S> AsyncWriteStream<S> {
151    /// Create [`AsyncWriteStream`] with the stream and default buffer size.
152    pub fn new(stream: S) -> Self {
153        Self::with_capacity(DEFAULT_BUF_SIZE, stream)
154    }
155
156    /// Create [`AsyncWriteStream`] with the stream and buffer size.
157    pub fn with_capacity(cap: usize, stream: S) -> Self {
158        Self::new_impl(SyncStreamWriteHalf::with_limits(
159            cap,
160            super::DEFAULT_MAX_BUFFER,
161            stream,
162        ))
163    }
164
165    fn new_impl(inner: SyncStreamWriteHalf<S>) -> Self {
166        Self {
167            inner,
168            write_future: None,
169            shutdown_future: None,
170            write_waker: None,
171            flush_waker: None,
172            close_waker: None,
173            closed: false,
174            _p: PhantomPinned,
175        }
176    }
177
178    /// Get the reference of the inner stream.
179    pub fn get_ref(&self) -> &S {
180        self.inner.get_ref()
181    }
182
183    /// Returns a mutable reference to the underlying stream.
184    pub fn get_mut(&mut self) -> &mut S {
185        self.inner.get_mut()
186    }
187
188    /// Consumes the `SyncStream`, returning the underlying stream.
189    pub fn into_inner(self) -> S {
190        self.inner.into_inner()
191    }
192}
193
// Polls the boxed future kept in the `Option` slot `$f` once with `$cx`, creating
// it from `$e` first when the slot is empty.
//
// - Pending: the future is stored back into `$f` and the enclosing function or
//   closure returns `Poll::Pending`.
// - Ready: the future is dropped (leaving `$f` empty) and the macro evaluates to
//   the future's output.
macro_rules! poll_future {
    ($f:expr, $cx:expr, $e:expr) => {{
        let mut fut = if let Some(existing) = $f.take() {
            existing
        } else {
            Box::pin($e)
        };
        match fut.as_mut().poll($cx) {
            Poll::Ready(output) => output,
            Poll::Pending => {
                *$f = Some(fut);
                return Poll::Pending;
            }
        }
    }};
}
210
// One step of a "try the buffer, refill on WouldBlock" loop.
//
// Evaluates the synchronous buffered operation `$io`:
// - any result other than `WouldBlock` clears the waker slot `$w` and makes the
//   enclosing function return it as `Poll::Ready`;
// - `WouldBlock` polls the refill/flush future `$f` instead. If that is pending, the
//   enclosing function returns `Poll::Pending`. If it fails, the error is propagated
//   with `?` (note: `$w` is NOT cleared on that path). On success, the macro
//   evaluates to `()` so the caller's loop can retry `$io`.
//
// `$cx` is accepted for call-site symmetry but is not used.
macro_rules! poll_future_would_block {
    ($cx:expr, $w:expr, $io:expr, $f:expr) => {{
        let outcome = $io;
        match outcome {
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                ready!($f)?;
            }
            done => {
                $w.take();
                return Poll::Ready(done);
            }
        }
    }};
}
228
/// Erases the lifetime of a mutable borrow, producing a `&'static mut T` to the
/// same referent.
///
/// # Safety
///
/// The caller must ensure the returned reference is never used after the referent
/// is dropped or moved, and is not used while another reference to the referent is
/// in use.
unsafe fn extend_lifetime_mut<T: ?Sized>(t: &mut T) -> &'static mut T {
    let raw: *mut T = t;
    // SAFETY: validity and exclusivity are guaranteed by the caller (see above).
    unsafe { &mut *raw }
}
232
/// Erases the lifetime of a shared borrow, producing a `&'static T` to the same
/// referent.
///
/// # Safety
///
/// The caller must ensure the returned reference is never used after the referent
/// is dropped, moved or mutated.
unsafe fn extend_lifetime<T: ?Sized>(t: &T) -> &'static T {
    let raw: *const T = t;
    // SAFETY: the referent's validity is guaranteed by the caller (see above).
    unsafe { &*raw }
}
236
/// Stores a clone of `waker` in `waker_slot`, unless the slot already holds a waker
/// that [`Waker::will_wake`] the same task (which avoids a redundant clone).
fn replace_waker(waker_slot: &mut Option<Waker>, waker: &Waker) {
    match waker_slot {
        Some(current) if current.will_wake(waker) => {}
        _ => *waker_slot = Some(waker.clone()),
    }
}
242
243impl<S: Splittable + 'static> futures_util::AsyncRead for AsyncStream<S>
244where
245    S::ReadHalf: AsyncRead + Unpin,
246{
247    fn poll_read(
248        self: Pin<&mut Self>,
249        cx: &mut Context<'_>,
250        buf: &mut [u8],
251    ) -> Poll<io::Result<usize>> {
252        self.project().read_inner.poll_read(cx, buf)
253    }
254}
255
256impl<S: Splittable + 'static> AsyncStream<S>
257where
258    S::ReadHalf: AsyncRead + Unpin,
259{
260    /// Attempt to read from the `AsyncRead` into `buf`.
261    ///
262    /// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
263    pub fn poll_read_uninit(
264        self: Pin<&mut Self>,
265        cx: &mut Context<'_>,
266        buf: &mut [MaybeUninit<u8>],
267    ) -> Poll<io::Result<usize>> {
268        self.project().read_inner.poll_read_uninit(cx, buf)
269    }
270}
271
272impl<S: Splittable + 'static> futures_util::AsyncBufRead for AsyncStream<S>
273where
274    S::ReadHalf: AsyncRead + Unpin,
275{
276    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
277        self.project().read_inner.poll_fill_buf(cx)
278    }
279
280    fn consume(self: Pin<&mut Self>, amt: usize) {
281        self.project().read_inner.consume(amt)
282    }
283}
284
impl<S: AsyncRead + Unpin + 'static> AsyncReadStream<S> {
    /// Drives the buffer-refill future (`inner.fill_read_buf()`) one step.
    ///
    /// The future is created on first use and kept in `read_future` while it is
    /// pending, so `poll_read`, `poll_read_uninit` and `poll_fill_buf` all share a
    /// single in-flight refill. It is polled with a waker assembled by
    /// `WakerArrayRef` from the three per-API read wakers, not with a caller's own
    /// context (presumably the combined waker wakes every registered task; see
    /// `WakerArrayRef`).
    ///
    /// Returns `Poll::Pending` while the refill is in progress, and the refill's
    /// result once it completes.
    fn poll_read_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
        let this = self.project();
        // SAFETY:
        // - The future won't live longer than the stream.
        // - The stream is `Unpin`.
        // NOTE(review): `poll_read`, `poll_read_uninit` and `poll_fill_buf` also take
        // `&mut` to `inner` while this extended borrow may be stored in
        // `read_future`; confirm this aliasing is acceptable.
        let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
        let arr = WakerArrayRef::new([
            this.read_waker.as_ref(),
            this.read_uninit_waker.as_ref(),
            this.read_buf_waker.as_ref(),
        ]);
        arr.with(|waker| {
            let cx = &mut Context::from_waker(waker);
            // Resume the stored refill future, or start a new one if none is in flight.
            let res = poll_future!(this.read_future, cx, inner.fill_read_buf());
            Poll::Ready(res)
        })
    }
}
304
305impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncRead for AsyncReadStream<S> {
306    fn poll_read(
307        mut self: Pin<&mut Self>,
308        cx: &mut Context<'_>,
309        buf: &mut [u8],
310    ) -> Poll<io::Result<usize>> {
311        replace_waker(self.as_mut().project().read_waker, cx.waker());
312        loop {
313            let this = self.as_mut().project();
314            poll_future_would_block!(
315                cx,
316                this.read_waker,
317                io::Read::read(this.inner.get_mut(), buf),
318                self.as_mut().poll_read_impl()
319            )
320        }
321    }
322}
323
324impl<S: AsyncRead + Unpin + 'static> AsyncReadStream<S> {
325    /// Attempt to read from the `AsyncRead` into `buf`.
326    ///
327    /// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
328    pub fn poll_read_uninit(
329        mut self: Pin<&mut Self>,
330        cx: &mut Context<'_>,
331        buf: &mut [MaybeUninit<u8>],
332    ) -> Poll<io::Result<usize>> {
333        replace_waker(self.as_mut().project().read_uninit_waker, cx.waker());
334        loop {
335            let this = self.as_mut().project();
336            poll_future_would_block!(
337                cx,
338                this.read_uninit_waker,
339                this.inner.get_mut().read_buf_uninit(buf),
340                self.as_mut().poll_read_impl()
341            )
342        }
343    }
344}
345impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncBufRead for AsyncReadStream<S> {
346    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
347        replace_waker(self.as_mut().project().read_buf_waker, cx.waker());
348        loop {
349            let this = self.as_mut().project();
350            poll_future_would_block!(
351                cx,
352                this.read_buf_waker,
353                // SAFETY: The buffer won't be accessed after the future is ready, and the future
354                // won't live longer than the stream.
355                io::BufRead::fill_buf(this.inner.get_mut()).map(|s| unsafe { extend_lifetime(s) }),
356                self.as_mut().poll_read_impl()
357            )
358        }
359    }
360
361    fn consume(self: Pin<&mut Self>, amt: usize) {
362        io::BufRead::consume(self.project().inner.get_mut(), amt)
363    }
364}
365
366impl<S: Splittable + 'static> futures_util::AsyncWrite for AsyncStream<S>
367where
368    S::WriteHalf: AsyncWrite + Unpin,
369{
370    fn poll_write(
371        self: Pin<&mut Self>,
372        cx: &mut Context<'_>,
373        buf: &[u8],
374    ) -> Poll<io::Result<usize>> {
375        self.project().write_inner.poll_write(cx, buf)
376    }
377
378    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
379        self.project().write_inner.poll_flush(cx)
380    }
381
382    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
383        self.project().write_inner.poll_close(cx)
384    }
385}
386
impl<S: AsyncWrite + Unpin + 'static> AsyncWriteStream<S> {
    /// Drives the write-buffer flush future (`inner.flush_write_buf()`) one step.
    ///
    /// The future is created on first use and kept in `write_future` while it is
    /// pending. `poll_write`, `poll_flush` and `poll_close` all drive this same
    /// future. It is polled with a waker assembled by `WakerArrayRef` from the write,
    /// flush and close wakers, not with a caller's own context.
    ///
    /// Returns `Poll::Pending` while the flush is in progress, and the future's
    /// `io::Result<usize>` once it completes.
    fn poll_flush_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
        let this = self.project();
        // SAFETY:
        // - The future won't live longer than the stream.
        // - The stream is `Unpin`.
        let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
        let arr = WakerArrayRef::new([
            this.write_waker.as_ref(),
            this.flush_waker.as_ref(),
            this.close_waker.as_ref(),
        ]);
        arr.with(|waker| {
            let cx = &mut Context::from_waker(waker);
            let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
            Poll::Ready(res)
        })
    }

    /// Drives the shutdown of the inner stream (`shutdown()`) one step.
    ///
    /// Returns `Poll::Ready(Ok(()))` immediately once a previous shutdown has
    /// succeeded (`closed`). Otherwise the shutdown future is created on first use
    /// and kept in `shutdown_future` while pending. `closed` is set only on success,
    /// so after an error the next call starts a new shutdown attempt.
    fn poll_close_impl(self: Pin<&mut Self>) -> Poll<io::Result<()>> {
        if self.closed {
            return Poll::Ready(Ok(()));
        }
        let this = self.project();
        // SAFETY:
        // - The future won't live longer than the stream.
        // - The stream is `Unpin`.
        let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
        let arr = WakerArrayRef::new([
            this.write_waker.as_ref(),
            this.flush_waker.as_ref(),
            this.close_waker.as_ref(),
        ]);
        arr.with(|waker| {
            let cx = &mut Context::from_waker(waker);
            let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
            // Remember a successful shutdown so later calls short-circuit above.
            Poll::Ready(res.inspect(|_| *this.closed = true))
        })
    }
}
427
428impl<S: AsyncWrite + Unpin + 'static> futures_util::AsyncWrite for AsyncWriteStream<S> {
429    fn poll_write(
430        mut self: Pin<&mut Self>,
431        cx: &mut Context<'_>,
432        buf: &[u8],
433    ) -> Poll<io::Result<usize>> {
434        replace_waker(self.as_mut().project().write_waker, cx.waker());
435        if self.shutdown_future.is_some() {
436            debug_assert!(self.write_future.is_none());
437            ready!(self.as_mut().poll_close_impl())?;
438        }
439        loop {
440            let this = self.as_mut().project();
441            poll_future_would_block!(
442                cx,
443                this.write_waker,
444                io::Write::write(this.inner.get_mut(), buf),
445                self.as_mut().poll_flush_impl()
446            )
447        }
448    }
449
450    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
451        replace_waker(self.as_mut().project().flush_waker, cx.waker());
452        if self.shutdown_future.is_some() {
453            debug_assert!(self.write_future.is_none());
454            ready!(self.as_mut().poll_close_impl())?;
455        }
456        let res = ready!(self.as_mut().poll_flush_impl());
457        self.project().flush_waker.take();
458        Poll::Ready(res.map(|_| ()))
459    }
460
461    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
462        replace_waker(self.as_mut().project().close_waker, cx.waker());
463        // Avoid shutdown on flush because the inner buffer might be passed to the
464        // driver.
465        if self.write_future.is_some() || self.inner.has_pending_write() {
466            debug_assert!(self.shutdown_future.is_none());
467            ready!(self.as_mut().poll_flush_impl())?;
468        }
469        let res = ready!(self.as_mut().poll_close_impl());
470        self.project().close_waker.take();
471        Poll::Ready(res)
472    }
473}
474
475impl<S: Splittable> Debug for AsyncStream<S> {
476    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477        f.debug_struct("AsyncStream").finish_non_exhaustive()
478    }
479}
480
481impl<S: Splittable> Splittable for AsyncStream<S> {
482    type ReadHalf = AsyncReadStream<S::ReadHalf>;
483    type WriteHalf = AsyncWriteStream<S::WriteHalf>;
484
485    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
486        (self.read_inner, self.write_inner)
487    }
488}
489
#[cfg(test)]
mod test {
    use futures_executor::block_on;
    use futures_util::AsyncWriteExt;

    use super::AsyncWriteStream;

    /// Bytes written before `close` must have reached the inner writer by the time
    /// `close` resolves.
    #[test]
    fn close() {
        let scenario = async {
            let mut stream = std::pin::pin!(AsyncWriteStream::new(Vec::<u8>::new()));
            let written = stream.write(b"hello").await.unwrap();
            assert_eq!(written, 5);
            stream.close().await.unwrap();
            assert_eq!(stream.get_ref(), b"hello");
        };
        block_on(scenario)
    }
}