compio_io\compat/
async_stream.rs1use std::{
2 fmt::Debug,
3 io::{self, BufRead},
4 mem::MaybeUninit,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use crate::{PinBoxFuture, compat::SyncStream};
10
11pub struct AsyncStream<S> {
13 inner: Pin<Box<SyncStream<S>>>,
16 read_future: Option<PinBoxFuture<io::Result<usize>>>,
17 write_future: Option<PinBoxFuture<io::Result<usize>>>,
18 shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
19}
20
21impl<S> AsyncStream<S> {
22 pub fn new(stream: S) -> Self {
24 Self::new_impl(SyncStream::new(stream))
25 }
26
27 pub fn with_capacity(cap: usize, stream: S) -> Self {
29 Self::new_impl(SyncStream::with_capacity(cap, stream))
30 }
31
32 fn new_impl(inner: SyncStream<S>) -> Self {
33 Self {
34 inner: Box::pin(inner),
35 read_future: None,
36 write_future: None,
37 shutdown_future: None,
38 }
39 }
40
41 pub fn get_ref(&self) -> &S {
43 self.inner.get_ref()
44 }
45}
46
/// Polls the future parked in the `Option` slot `$f`, first creating it from
/// `$e` when the slot is empty, and evaluates to the future's output once it
/// completes.
///
/// If the future is not ready yet, it is stored back into `$f` and the
/// *enclosing function* returns `Poll::Pending`.
macro_rules! poll_future {
    ($f:expr, $cx:expr, $e:expr) => {{
        let mut pinned = if let Some(existing) = $f.take() {
            existing
        } else {
            Box::pin($e)
        };
        if let Poll::Ready(output) = pinned.as_mut().poll($cx) {
            output
        } else {
            $f.replace(pinned);
            return Poll::Pending;
        }
    }};
}
63
/// Drives a "refill/flush, then retry the sync call" cycle.
///
/// - `$f`: `Option` slot holding the in-flight buffer future.
/// - `$cx`: the task context.
/// - `$e`: expression creating a new buffer future; only evaluated when `$io`
///   reports `WouldBlock`.
/// - `$io`: the synchronous operation against the internal buffer.
///
/// A stored future is polled first. If it is still pending, the enclosing
/// function returns `Poll::Pending`. If it fails, the error is returned to the
/// caller. Otherwise `$io` is attempted, and a `WouldBlock` result parks a
/// fresh future in `$f` and wakes the task so it gets polled.
macro_rules! poll_future_would_block {
    ($f:expr, $cx:expr, $e:expr, $io:expr) => {{
        if let Some(mut fut) = $f.take() {
            match fut.as_mut().poll($cx) {
                Poll::Pending => {
                    $f.replace(fut);
                    return Poll::Pending;
                }
                // Surface errors from the refill/flush. Dropping them would make
                // `$io` report `WouldBlock` again and retry the failing I/O forever.
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                // Buffer operation finished; fall through and retry the sync call.
                Poll::Ready(Ok(_)) => {}
            }
        }

        match $io {
            Ok(value) => Poll::Ready(Ok(value)),
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                $f.replace(Box::pin($e));
                // The new future has not been polled yet; wake the task so the
                // next poll drives it.
                $cx.waker().wake_by_ref();
                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }};
}
84
85impl<S: crate::AsyncRead + 'static> futures_util::AsyncRead for AsyncStream<S> {
86 fn poll_read(
87 mut self: Pin<&mut Self>,
88 cx: &mut Context<'_>,
89 buf: &mut [u8],
90 ) -> Poll<io::Result<usize>> {
91 let inner: &'static mut SyncStream<S> =
95 unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
96
97 poll_future_would_block!(
98 self.read_future,
99 cx,
100 inner.fill_read_buf(),
101 io::Read::read(inner, buf)
102 )
103 }
104}
105
106impl<S: crate::AsyncRead + 'static> AsyncStream<S> {
107 pub fn poll_read_uninit(
111 mut self: Pin<&mut Self>,
112 cx: &mut Context<'_>,
113 buf: &mut [MaybeUninit<u8>],
114 ) -> Poll<io::Result<usize>> {
115 let inner: &'static mut SyncStream<S> =
116 unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
117 poll_future_would_block!(
118 self.read_future,
119 cx,
120 inner.fill_read_buf(),
121 inner.read_buf_uninit(buf)
122 )
123 }
124}
125
126impl<S: crate::AsyncRead + 'static> futures_util::AsyncBufRead for AsyncStream<S> {
127 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
128 let inner: &'static mut SyncStream<S> =
129 unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
130 poll_future_would_block!(
131 self.read_future,
132 cx,
133 inner.fill_read_buf(),
134 io::BufRead::fill_buf(inner).map(|slice| unsafe { &*(slice as *const _) })
136 )
137 }
138
139 fn consume(mut self: Pin<&mut Self>, amt: usize) {
140 unsafe { self.inner.as_mut().get_unchecked_mut().consume(amt) }
141 }
142}
143
144impl<S: crate::AsyncWrite + 'static> futures_util::AsyncWrite for AsyncStream<S> {
145 fn poll_write(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 buf: &[u8],
149 ) -> Poll<io::Result<usize>> {
150 if self.shutdown_future.is_some() {
151 debug_assert!(self.write_future.is_none());
152 return Poll::Pending;
153 }
154
155 let inner: &'static mut SyncStream<S> =
156 unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
157 poll_future_would_block!(
158 self.write_future,
159 cx,
160 inner.flush_write_buf(),
161 io::Write::write(inner, buf)
162 )
163 }
164
165 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166 if self.shutdown_future.is_some() {
167 debug_assert!(self.write_future.is_none());
168 return Poll::Pending;
169 }
170
171 let inner: &'static mut SyncStream<S> =
172 unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
173 let res = poll_future!(self.write_future, cx, inner.flush_write_buf());
174 Poll::Ready(res.map(|_| ()))
175 }
176
177 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178 if self.write_future.is_some() {
181 debug_assert!(self.shutdown_future.is_none());
182 return Poll::Pending;
183 }
184
185 let inner: &'static mut SyncStream<S> =
186 unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
187 let res = poll_future!(self.shutdown_future, cx, inner.get_mut().shutdown());
188 Poll::Ready(res)
189 }
190}
191
192impl<S: Debug> Debug for AsyncStream<S> {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.debug_struct("AsyncStream")
195 .field("inner", &self.inner)
196 .finish_non_exhaustive()
197 }
198}