Skip to main content

compio_io/compat/
async_stream.rs

1use std::{
2    fmt::Debug,
3    io::{self, BufRead},
4    marker::PhantomPinned,
5    mem::MaybeUninit,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll, Wake, Waker, ready},
9};
10
11use pin_project_lite::pin_project;
12
13use crate::{AsyncRead, AsyncWrite, PinBoxFuture, compat::SyncStream, util::DEFAULT_BUF_SIZE};
14
15pin_project! {
16    /// A stream wrapper for [`futures_util::io`] traits.
17    pub struct AsyncStream<S> {
18        #[pin]
19        inner: SyncStream<S>,
20        read_future: Option<PinBoxFuture<io::Result<usize>>>,
21        write_future: Option<PinBoxFuture<io::Result<usize>>>,
22        shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
23        read_waker: Option<Waker>,
24        read_uninit_waker: Option<Waker>,
25        read_buf_waker: Option<Waker>,
26        write_waker: Option<Waker>,
27        flush_waker: Option<Waker>,
28        close_waker: Option<Waker>,
29        #[pin]
30        _p: PhantomPinned,
31    }
32}
33
34impl<S> AsyncStream<S> {
35    /// Create [`AsyncStream`] with the stream and default buffer size.
36    pub fn new(stream: S) -> Self {
37        Self::new_impl(SyncStream::new(stream))
38    }
39
40    /// Create [`AsyncStream`] with the stream and buffer size.
41    pub fn with_capacity(cap: usize, stream: S) -> Self {
42        Self::new_impl(SyncStream::with_capacity(cap, stream))
43    }
44
45    fn new_impl(inner: SyncStream<S>) -> Self {
46        Self {
47            inner,
48            read_future: None,
49            write_future: None,
50            shutdown_future: None,
51            read_waker: None,
52            read_uninit_waker: None,
53            read_buf_waker: None,
54            write_waker: None,
55            flush_waker: None,
56            close_waker: None,
57            _p: PhantomPinned,
58        }
59    }
60
61    /// Get the reference of the inner stream.
62    pub fn get_ref(&self) -> &S {
63        self.inner.get_ref()
64    }
65
66    /// Returns a mutable reference to the underlying stream.
67    pub fn get_mut(&mut self) -> &mut S {
68        self.inner.get_mut()
69    }
70
71    /// Consumes the `SyncStream`, returning the underlying stream.
72    pub fn into_inner(self) -> S {
73        self.inner.into_inner()
74    }
75}
76
77pin_project! {
78    /// A read stream wrapper for [`futures_util::io`].
79    ///
80    /// It doesn't support write and shutdown operations, making looser
81    /// requirements on the inner stream.
82    pub struct AsyncReadStream<S> {
83        #[pin]
84        inner: SyncStream<S>,
85        read_future: Option<PinBoxFuture<io::Result<usize>>>,
86        read_waker: Option<Waker>,
87        read_uninit_waker: Option<Waker>,
88        read_buf_waker: Option<Waker>,
89        #[pin]
90        _p: PhantomPinned,
91    }
92}
93
94impl<S> AsyncReadStream<S> {
95    /// Create [`AsyncReadStream`] with the stream and default buffer size.
96    pub fn new(stream: S) -> Self {
97        Self::with_capacity(DEFAULT_BUF_SIZE, stream)
98    }
99
100    /// Create [`AsyncReadStream`] with the stream and buffer size.
101    pub fn with_capacity(cap: usize, stream: S) -> Self {
102        Self::new_impl(SyncStream::with_limits2(
103            cap,
104            0,
105            cap,
106            SyncStream::<S>::DEFAULT_MAX_BUFFER,
107            stream,
108        ))
109    }
110
111    fn new_impl(inner: SyncStream<S>) -> Self {
112        Self {
113            inner,
114            read_future: None,
115            read_waker: None,
116            read_uninit_waker: None,
117            read_buf_waker: None,
118            _p: PhantomPinned,
119        }
120    }
121
122    /// Get the reference of the inner stream.
123    pub fn get_ref(&self) -> &S {
124        self.inner.get_ref()
125    }
126
127    /// Returns a mutable reference to the underlying stream.
128    pub fn get_mut(&mut self) -> &mut S {
129        self.inner.get_mut()
130    }
131
132    /// Consumes the `SyncStream`, returning the underlying stream.
133    pub fn into_inner(self) -> S {
134        self.inner.into_inner()
135    }
136}
137
138pin_project! {
139    /// A write stream wrapper for [`futures_util::io`].
140    ///
141    /// It doesn't support read operations, making looser requirements on the inner stream.
142    pub struct AsyncWriteStream<S> {
143        #[pin]
144        inner: SyncStream<S>,
145        write_future: Option<PinBoxFuture<io::Result<usize>>>,
146        shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
147        write_waker: Option<Waker>,
148        flush_waker: Option<Waker>,
149        close_waker: Option<Waker>,
150        #[pin]
151        _p: PhantomPinned,
152    }
153}
154
155impl<S> AsyncWriteStream<S> {
156    /// Create [`AsyncWriteStream`] with the stream and default buffer size.
157    pub fn new(stream: S) -> Self {
158        Self::with_capacity(DEFAULT_BUF_SIZE, stream)
159    }
160
161    /// Create [`AsyncWriteStream`] with the stream and buffer size.
162    pub fn with_capacity(cap: usize, stream: S) -> Self {
163        Self::new_impl(SyncStream::with_limits2(
164            0,
165            cap,
166            cap,
167            SyncStream::<S>::DEFAULT_MAX_BUFFER,
168            stream,
169        ))
170    }
171
172    fn new_impl(inner: SyncStream<S>) -> Self {
173        Self {
174            inner,
175            write_future: None,
176            shutdown_future: None,
177            write_waker: None,
178            flush_waker: None,
179            close_waker: None,
180            _p: PhantomPinned,
181        }
182    }
183
184    /// Get the reference of the inner stream.
185    pub fn get_ref(&self) -> &S {
186        self.inner.get_ref()
187    }
188
189    /// Returns a mutable reference to the underlying stream.
190    pub fn get_mut(&mut self) -> &mut S {
191        self.inner.get_mut()
192    }
193
194    /// Consumes the `SyncStream`, returning the underlying stream.
195    pub fn into_inner(self) -> S {
196        self.inner.into_inner()
197    }
198}
199
/// Polls the boxed future kept in the `Option` slot `$f` with context `$cx`.
///
/// If the slot is empty, a new future is created from `$e` (boxed and pinned).
/// While the future is pending, it is stored back into the slot and the
/// enclosing function returns `Poll::Pending`. Once it completes, the slot is
/// left empty and the macro evaluates to the future's output.
macro_rules! poll_future {
    ($f:expr, $cx:expr, $e:expr) => {{
        let mut pinned = if let Some(existing) = $f.take() {
            existing
        } else {
            Box::pin($e)
        };
        match pinned.as_mut().poll($cx) {
            Poll::Ready(output) => output,
            Poll::Pending => {
                *$f = Some(pinned);
                return Poll::Pending;
            }
        }
    }};
}
216
/// Checks the result `$io` of a synchronous attempt on the buffered stream.
///
/// * Any result other than a `WouldBlock` error is final: the waker slot `$w`
///   is cleared and the enclosing function returns `Poll::Ready(result)`.
/// * On `WouldBlock`, the progress expression `$f` (a
///   `Poll<io::Result<_>>`) is driven. `Pending` is returned as is, and an
///   error is propagated with `?` (leaving `$w` untouched). On success the
///   macro evaluates to `()`, so the caller can loop and retry.
///
/// `$cx` is accepted but not used by the expansion.
macro_rules! poll_future_would_block {
    ($cx:expr, $w:expr, $io:expr, $f:expr) => {{
        match $io {
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                ready!($f)?;
            }
            finished => {
                $w.take();
                return Poll::Ready(finished);
            }
        }
    }};
}
234
/// Detaches a mutable reference from the borrow it came from, giving it a
/// `'static` lifetime.
///
/// Used to build futures that borrow the wrapped `SyncStream` and are then
/// stored inside the same wrapper, which plain lifetimes cannot express.
///
/// # Safety
///
/// The caller must guarantee that, for as long as the returned reference (or
/// anything derived from it) is used:
/// - the referent stays alive and is not moved, and
/// - no other access to the referent conflicts with it.
unsafe fn extend_lifetime_mut<T: ?Sized>(t: &mut T) -> &'static mut T {
    // SAFETY: validity and exclusivity are upheld by the caller (see above).
    unsafe { &mut *(t as *mut T) }
}
238
/// Detaches a shared reference from the borrow it came from, giving it a
/// `'static` lifetime.
///
/// Used so that a slice returned from the internal buffer does not keep the
/// wrapper borrowed while the caller goes on to poll the wrapper again.
///
/// # Safety
///
/// The caller must guarantee that the referent stays alive, is not moved and
/// is not mutated for as long as the returned reference is used.
unsafe fn extend_lifetime<T: ?Sized>(t: &T) -> &'static T {
    // SAFETY: validity and immutability are upheld by the caller (see above).
    unsafe { &*(t as *const T) }
}
242
/// Stores a clone of `waker` in `waker_slot`, unless the slot already holds a
/// waker that [`Waker::will_wake`] reports as equivalent (which avoids a
/// needless clone).
fn replace_waker(waker_slot: &mut Option<Waker>, waker: &Waker) {
    let up_to_date = matches!(waker_slot, Some(current) if current.will_wake(waker));
    if !up_to_date {
        *waker_slot = Some(waker.clone());
    }
}
248
249impl<S: AsyncRead + Unpin + 'static> AsyncStream<S>
250where
251    for<'a> &'a S: AsyncRead,
252{
253    fn poll_read_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
254        let this = self.project();
255        // SAFETY:
256        // - The future won't live longer than the stream.
257        // - The stream is internally mutable.
258        // - The future only accesses the corresponding buffer and fields.
259        //   - No access overlap between the futures.
260        let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
261        let arr = WakerArray([
262            this.read_waker.as_ref().cloned(),
263            this.read_uninit_waker.as_ref().cloned(),
264            this.read_buf_waker.as_ref().cloned(),
265        ]);
266        let waker = Waker::from(Arc::new(arr));
267        let cx = &mut Context::from_waker(&waker);
268        let res = poll_future!(this.read_future, cx, inner.fill_read_buf());
269        Poll::Ready(res)
270    }
271}
272
273impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncRead for AsyncStream<S>
274where
275    for<'a> &'a S: AsyncRead,
276{
277    fn poll_read(
278        mut self: Pin<&mut Self>,
279        cx: &mut Context<'_>,
280        buf: &mut [u8],
281    ) -> Poll<io::Result<usize>> {
282        replace_waker(self.as_mut().project().read_waker, cx.waker());
283        loop {
284            let this = self.as_mut().project();
285            poll_future_would_block!(
286                cx,
287                this.read_waker,
288                io::Read::read(this.inner.get_mut(), buf),
289                self.as_mut().poll_read_impl()
290            )
291        }
292    }
293}
294
295impl<S: AsyncRead + Unpin + 'static> AsyncStream<S>
296where
297    for<'a> &'a S: AsyncRead,
298{
299    /// Attempt to read from the `AsyncRead` into `buf`.
300    ///
301    /// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
302    pub fn poll_read_uninit(
303        mut self: Pin<&mut Self>,
304        cx: &mut Context<'_>,
305        buf: &mut [MaybeUninit<u8>],
306    ) -> Poll<io::Result<usize>> {
307        replace_waker(self.as_mut().project().read_uninit_waker, cx.waker());
308        loop {
309            let this = self.as_mut().project();
310            poll_future_would_block!(
311                cx,
312                this.read_uninit_waker,
313                this.inner.get_mut().read_buf_uninit(buf),
314                self.as_mut().poll_read_impl()
315            )
316        }
317    }
318}
319
320impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncBufRead for AsyncStream<S>
321where
322    for<'a> &'a S: AsyncRead,
323{
324    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
325        replace_waker(self.as_mut().project().read_buf_waker, cx.waker());
326        loop {
327            let this = self.as_mut().project();
328            poll_future_would_block!(
329                cx,
330                this.read_buf_waker,
331                // SAFETY: The buffer won't be accessed after the future is ready, and the future
332                // won't live longer than the stream.
333                io::BufRead::fill_buf(this.inner.get_mut()).map(|s| unsafe { extend_lifetime(s) }),
334                self.as_mut().poll_read_impl()
335            )
336        }
337    }
338
339    fn consume(self: Pin<&mut Self>, amt: usize) {
340        self.project().inner.consume(amt)
341    }
342}
343
344impl<S: AsyncRead + Unpin + 'static> AsyncReadStream<S> {
345    fn poll_read_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
346        let this = self.project();
347        // SAFETY:
348        // - The future won't live longer than the stream.
349        // - The stream is `Unpin`.
350        let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
351        let arr = WakerArray([
352            this.read_waker.as_ref().cloned(),
353            this.read_uninit_waker.as_ref().cloned(),
354            this.read_buf_waker.as_ref().cloned(),
355        ]);
356        let waker = Waker::from(Arc::new(arr));
357        let cx = &mut Context::from_waker(&waker);
358        let res = poll_future!(this.read_future, cx, inner.fill_read_buf());
359        Poll::Ready(res)
360    }
361}
362
363impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncRead for AsyncReadStream<S> {
364    fn poll_read(
365        mut self: Pin<&mut Self>,
366        cx: &mut Context<'_>,
367        buf: &mut [u8],
368    ) -> Poll<io::Result<usize>> {
369        replace_waker(self.as_mut().project().read_waker, cx.waker());
370        loop {
371            let this = self.as_mut().project();
372            poll_future_would_block!(
373                cx,
374                this.read_waker,
375                io::Read::read(this.inner.get_mut(), buf),
376                self.as_mut().poll_read_impl()
377            )
378        }
379    }
380}
381
382impl<S: AsyncRead + Unpin + 'static> AsyncReadStream<S> {
383    /// Attempt to read from the `AsyncRead` into `buf`.
384    ///
385    /// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
386    pub fn poll_read_uninit(
387        mut self: Pin<&mut Self>,
388        cx: &mut Context<'_>,
389        buf: &mut [MaybeUninit<u8>],
390    ) -> Poll<io::Result<usize>> {
391        replace_waker(self.as_mut().project().read_uninit_waker, cx.waker());
392        loop {
393            let this = self.as_mut().project();
394            poll_future_would_block!(
395                cx,
396                this.read_uninit_waker,
397                this.inner.get_mut().read_buf_uninit(buf),
398                self.as_mut().poll_read_impl()
399            )
400        }
401    }
402}
403impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncBufRead for AsyncReadStream<S> {
404    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
405        replace_waker(self.as_mut().project().read_buf_waker, cx.waker());
406        loop {
407            let this = self.as_mut().project();
408            poll_future_would_block!(
409                cx,
410                this.read_buf_waker,
411                // SAFETY: The buffer won't be accessed after the future is ready, and the future
412                // won't live longer than the stream.
413                io::BufRead::fill_buf(this.inner.get_mut()).map(|s| unsafe { extend_lifetime(s) }),
414                self.as_mut().poll_read_impl()
415            )
416        }
417    }
418
419    fn consume(self: Pin<&mut Self>, amt: usize) {
420        self.project().inner.consume(amt)
421    }
422}
423
424impl<S: AsyncWrite + Unpin + 'static> AsyncStream<S>
425where
426    for<'a> &'a S: AsyncWrite,
427{
428    fn poll_flush_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
429        let this = self.project();
430        // SAFETY:
431        // - The future won't live longer than the stream.
432        // - The stream is internally mutable.
433        // - The future only accesses the corresponding buffer and fields.
434        //   - No access overlap between the futures.
435        let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
436        let arr = WakerArray([
437            this.write_waker.as_ref().cloned(),
438            this.flush_waker.as_ref().cloned(),
439            this.close_waker.as_ref().cloned(),
440        ]);
441        let waker = Waker::from(Arc::new(arr));
442        let cx = &mut Context::from_waker(&waker);
443        let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
444        Poll::Ready(res)
445    }
446
447    fn poll_close_impl(self: Pin<&mut Self>) -> Poll<io::Result<()>> {
448        let this = self.project();
449        // SAFETY:
450        // - The future won't live longer than the stream.
451        // - The stream is internally mutable.
452        // - The future only accesses the corresponding buffer and fields.
453        //   - No access overlap between the futures.
454        let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
455        let arr = WakerArray([
456            this.write_waker.as_ref().cloned(),
457            this.flush_waker.as_ref().cloned(),
458            this.close_waker.as_ref().cloned(),
459        ]);
460        let waker = Waker::from(Arc::new(arr));
461        let cx = &mut Context::from_waker(&waker);
462        let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
463        Poll::Ready(res)
464    }
465}
466
467impl<S: AsyncWrite + Unpin + 'static> futures_util::AsyncWrite for AsyncStream<S>
468where
469    for<'a> &'a S: AsyncWrite,
470{
471    fn poll_write(
472        mut self: Pin<&mut Self>,
473        cx: &mut Context<'_>,
474        buf: &[u8],
475    ) -> Poll<io::Result<usize>> {
476        replace_waker(self.as_mut().project().write_waker, cx.waker());
477        if self.shutdown_future.is_some() {
478            debug_assert!(self.write_future.is_none());
479            ready!(self.as_mut().poll_close_impl())?;
480        }
481        loop {
482            let this = self.as_mut().project();
483            poll_future_would_block!(
484                cx,
485                this.write_waker,
486                io::Write::write(this.inner.get_mut(), buf),
487                self.as_mut().poll_flush_impl()
488            )
489        }
490    }
491
492    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
493        replace_waker(self.as_mut().project().flush_waker, cx.waker());
494        if self.shutdown_future.is_some() {
495            debug_assert!(self.write_future.is_none());
496            ready!(self.as_mut().poll_close_impl())?;
497        }
498        let res = ready!(self.as_mut().poll_flush_impl());
499        self.project().flush_waker.take();
500        Poll::Ready(res.map(|_| ()))
501    }
502
503    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
504        replace_waker(self.as_mut().project().close_waker, cx.waker());
505        // Avoid shutdown on flush because the inner buffer might be passed to the
506        // driver.
507        if self.write_future.is_some() || self.inner.has_pending_write() {
508            debug_assert!(self.shutdown_future.is_none());
509            ready!(self.as_mut().poll_flush_impl())?;
510        }
511        let res = ready!(self.as_mut().poll_close_impl());
512        self.project().close_waker.take();
513        Poll::Ready(res)
514    }
515}
516
517impl<S: AsyncWrite + Unpin + 'static> AsyncWriteStream<S> {
518    fn poll_flush_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
519        let this = self.project();
520        // SAFETY:
521        // - The future won't live longer than the stream.
522        // - The stream is `Unpin`.
523        let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
524        let arr = WakerArray([
525            this.write_waker.as_ref().cloned(),
526            this.flush_waker.as_ref().cloned(),
527            this.close_waker.as_ref().cloned(),
528        ]);
529        let waker = Waker::from(Arc::new(arr));
530        let cx = &mut Context::from_waker(&waker);
531        let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
532        Poll::Ready(res)
533    }
534
535    fn poll_close_impl(self: Pin<&mut Self>) -> Poll<io::Result<()>> {
536        let this = self.project();
537        // SAFETY:
538        // - The future won't live longer than the stream.
539        // - The stream is `Unpin`.
540        let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
541        let arr = WakerArray([
542            this.write_waker.as_ref().cloned(),
543            this.flush_waker.as_ref().cloned(),
544            this.close_waker.as_ref().cloned(),
545        ]);
546        let waker = Waker::from(Arc::new(arr));
547        let cx = &mut Context::from_waker(&waker);
548        let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
549        Poll::Ready(res)
550    }
551}
552
553impl<S: AsyncWrite + Unpin + 'static> futures_util::AsyncWrite for AsyncWriteStream<S> {
554    fn poll_write(
555        mut self: Pin<&mut Self>,
556        cx: &mut Context<'_>,
557        buf: &[u8],
558    ) -> Poll<io::Result<usize>> {
559        replace_waker(self.as_mut().project().write_waker, cx.waker());
560        if self.shutdown_future.is_some() {
561            debug_assert!(self.write_future.is_none());
562            ready!(self.as_mut().poll_close_impl())?;
563        }
564        loop {
565            let this = self.as_mut().project();
566            poll_future_would_block!(
567                cx,
568                this.write_waker,
569                io::Write::write(this.inner.get_mut(), buf),
570                self.as_mut().poll_flush_impl()
571            )
572        }
573    }
574
575    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
576        replace_waker(self.as_mut().project().flush_waker, cx.waker());
577        if self.shutdown_future.is_some() {
578            debug_assert!(self.write_future.is_none());
579            ready!(self.as_mut().poll_close_impl())?;
580        }
581        let res = ready!(self.as_mut().poll_flush_impl());
582        self.project().flush_waker.take();
583        Poll::Ready(res.map(|_| ()))
584    }
585
586    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
587        replace_waker(self.as_mut().project().close_waker, cx.waker());
588        // Avoid shutdown on flush because the inner buffer might be passed to the
589        // driver.
590        if self.write_future.is_some() || self.inner.has_pending_write() {
591            debug_assert!(self.shutdown_future.is_none());
592            ready!(self.as_mut().poll_flush_impl())?;
593        }
594        let res = ready!(self.as_mut().poll_close_impl());
595        self.project().close_waker.take();
596        Poll::Ready(res)
597    }
598}
599
600impl<S: Debug> Debug for AsyncStream<S> {
601    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
602        f.debug_struct("AsyncStream")
603            .field("inner", &self.inner)
604            .finish_non_exhaustive()
605    }
606}
607
/// A waker that fans out to up to `N` inner wakers.
///
/// Waking it wakes every `Some` entry by reference; `None` slots are skipped.
struct WakerArray<const N: usize>([Option<Waker>; N]);

impl<const N: usize> Wake for WakerArray<N> {
    fn wake(self: Arc<Self>) {
        <Self as Wake>::wake_by_ref(&self);
    }

    /// Overrides the default `wake_by_ref`, which would clone the `Arc` only
    /// to call `wake`. The inner wakers are woken by reference either way.
    fn wake_by_ref(self: &Arc<Self>) {
        for waker in self.0.iter().flatten() {
            waker.wake_by_ref();
        }
    }
}
619
#[cfg(test)]
mod test {
    use futures_executor::block_on;
    use futures_util::AsyncWriteExt;

    use super::AsyncWriteStream;

    /// After `close`, every byte written through the wrapper must be visible
    /// in the wrapped `Vec`.
    #[test]
    fn close() {
        let payload: &[u8] = b"hello";
        let contents = block_on(async {
            let mut stream = std::pin::pin!(AsyncWriteStream::new(Vec::<u8>::new()));
            let written = stream.write(payload).await.unwrap();
            assert_eq!(written, payload.len());
            stream.close().await.unwrap();
            stream.get_ref().clone()
        });
        assert_eq!(contents, payload);
    }
}