1use std::{
2 fmt::Debug,
3 io,
4 marker::PhantomPinned,
5 mem::MaybeUninit,
6 pin::Pin,
7 task::{Context, Poll, Waker, ready},
8};
9
10use pin_project_lite::pin_project;
11
12use super::waker_array::WakerArrayRef;
13use crate::{
14 AsyncRead, AsyncWrite, PinBoxFuture,
15 compat::{SyncStream, SyncStreamReadHalf, SyncStreamWriteHalf},
16 util::{DEFAULT_BUF_SIZE, Splittable},
17};
18
pin_project! {
    /// Duplex stream wrapper pairing an [`AsyncReadStream`] over `S::ReadHalf` with an
    /// [`AsyncWriteStream`] over `S::WriteHalf`.
    ///
    /// The `futures_util` I/O trait implementations in this file forward every call to the
    /// matching half.
    pub struct AsyncStream<S: Splittable> {
        // Read side; pinned structurally so polls can be forwarded through the projection.
        #[pin]
        read_inner: AsyncReadStream<S::ReadHalf>,
        // Write side; pinned structurally for the same reason.
        #[pin]
        write_inner: AsyncWriteStream<S::WriteHalf>,
        // Marker that opts the wrapper out of `Unpin`.
        #[pin]
        _p: PhantomPinned,
    }
}
30
31impl<S: Splittable> AsyncStream<S> {
32 pub fn new(stream: S) -> Self {
34 Self::new_impl(SyncStream::new(stream))
35 }
36
37 pub fn with_capacity(cap: usize, stream: S) -> Self {
39 Self::new_impl(SyncStream::with_capacity(cap, stream))
40 }
41
42 pub fn with_limits(cap: usize, max_buffer_size: usize, stream: S) -> Self {
45 Self::new_impl(SyncStream::with_limits(cap, max_buffer_size, stream))
46 }
47
48 fn new_impl(inner: SyncStream<S>) -> Self {
49 let (read_inner, write_inner) = inner.split();
50 Self {
51 read_inner: AsyncReadStream::new_impl(read_inner),
52 write_inner: AsyncWriteStream::new_impl(write_inner),
53 _p: PhantomPinned,
54 }
55 }
56
57 pub fn get_ref(&self) -> (&S::ReadHalf, &S::WriteHalf) {
59 (self.read_inner.get_ref(), self.write_inner.get_ref())
60 }
61
62 pub fn get_mut(&mut self) -> (&mut S::ReadHalf, &mut S::WriteHalf) {
64 (self.read_inner.get_mut(), self.write_inner.get_mut())
65 }
66
67 pub fn into_inner(self) -> (S::ReadHalf, S::WriteHalf) {
69 (self.read_inner.into_inner(), self.write_inner.into_inner())
70 }
71}
72
73pin_project! {
74 pub struct AsyncReadStream<S> {
79 #[pin]
80 inner: SyncStreamReadHalf<S>,
81 read_future: Option<PinBoxFuture<io::Result<usize>>>,
82 read_waker: Option<Waker>,
83 read_uninit_waker: Option<Waker>,
84 read_buf_waker: Option<Waker>,
85 #[pin]
86 _p: PhantomPinned,
87 }
88}
89
90impl<S> AsyncReadStream<S> {
91 pub fn new(stream: S) -> Self {
93 Self::with_capacity(DEFAULT_BUF_SIZE, stream)
94 }
95
96 pub fn with_capacity(cap: usize, stream: S) -> Self {
98 Self::new_impl(SyncStreamReadHalf::with_limits(
99 cap,
100 super::DEFAULT_MAX_BUFFER,
101 stream,
102 ))
103 }
104
105 fn new_impl(inner: SyncStreamReadHalf<S>) -> Self {
106 Self {
107 inner,
108 read_future: None,
109 read_waker: None,
110 read_uninit_waker: None,
111 read_buf_waker: None,
112 _p: PhantomPinned,
113 }
114 }
115
116 pub fn get_ref(&self) -> &S {
118 self.inner.get_ref()
119 }
120
121 pub fn get_mut(&mut self) -> &mut S {
123 self.inner.get_mut()
124 }
125
126 pub fn into_inner(self) -> S {
128 self.inner.into_inner()
129 }
130}
131
132pin_project! {
133 pub struct AsyncWriteStream<S> {
137 #[pin]
138 inner: SyncStreamWriteHalf<S>,
139 write_future: Option<PinBoxFuture<io::Result<usize>>>,
140 shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
141 write_waker: Option<Waker>,
142 flush_waker: Option<Waker>,
143 close_waker: Option<Waker>,
144 closed: bool,
145 #[pin]
146 _p: PhantomPinned,
147 }
148}
149
150impl<S> AsyncWriteStream<S> {
151 pub fn new(stream: S) -> Self {
153 Self::with_capacity(DEFAULT_BUF_SIZE, stream)
154 }
155
156 pub fn with_capacity(cap: usize, stream: S) -> Self {
158 Self::new_impl(SyncStreamWriteHalf::with_limits(
159 cap,
160 super::DEFAULT_MAX_BUFFER,
161 stream,
162 ))
163 }
164
165 fn new_impl(inner: SyncStreamWriteHalf<S>) -> Self {
166 Self {
167 inner,
168 write_future: None,
169 shutdown_future: None,
170 write_waker: None,
171 flush_waker: None,
172 close_waker: None,
173 closed: false,
174 _p: PhantomPinned,
175 }
176 }
177
178 pub fn get_ref(&self) -> &S {
180 self.inner.get_ref()
181 }
182
183 pub fn get_mut(&mut self) -> &mut S {
185 self.inner.get_mut()
186 }
187
188 pub fn into_inner(self) -> S {
190 self.inner.into_inner()
191 }
192}
193
// Drives the boxed future kept in the slot `$f`.
//
// The slot is emptied first; when it was empty a new future is created from `$e` and boxed.
// The future is polled once with `$cx`:
// * `Pending` — the future is stored back into `$f` and the *enclosing function or closure*
//   returns `Poll::Pending`;
// * `Ready` — the macro evaluates to the output and the slot stays empty.
macro_rules! poll_future {
    ($f:expr, $cx:expr, $e:expr) => {{
        let mut pending = if let Some(existing) = $f.take() {
            existing
        } else {
            Box::pin($e)
        };
        let polled = pending.as_mut().poll($cx);
        match polled {
            Poll::Ready(output) => output,
            Poll::Pending => {
                // Keep the unfinished future so the next poll resumes it.
                $f.replace(pending);
                return Poll::Pending;
            }
        }
    }};
}
210
// Handles one attempt of a synchronous I/O call `$io` inside a poll loop.
//
// * Any result other than a `WouldBlock` error clears the waker slot `$w` and makes the
//   enclosing function return `Poll::Ready` with that result.
// * A `WouldBlock` error evaluates `$f`: the enclosing function returns `Poll::Pending` when
//   `$f` is pending, returns the error when `$f` failed, and otherwise the macro falls
//   through so the caller's loop can retry.
//
// `$cx` is accepted for call-site symmetry but is not used by the expansion.
macro_rules! poll_future_would_block {
    ($cx:expr, $w:expr, $io:expr, $f:expr) => {{
        let outcome = $io;
        let would_block =
            matches!(&outcome, Err(err) if err.kind() == io::ErrorKind::WouldBlock);
        if !would_block {
            $w.take();
            return Poll::Ready(outcome);
        }
        ready!($f)?;
    }};
}
228
/// Reborrows `t` as an exclusive reference with an unbounded (`'static`) lifetime.
///
/// # Safety
///
/// The caller must ensure the referent stays alive and is not moved for as long as the
/// returned reference is used, and that no other reference to it is used during that time.
unsafe fn extend_lifetime_mut<T: ?Sized>(t: &mut T) -> &'static mut T {
    let raw: *mut T = t;
    // SAFETY: `raw` was just derived from a valid exclusive reference; the caller upholds
    // the validity of the referent for the extended lifetime.
    unsafe { &mut *raw }
}
232
/// Reborrows `t` as a shared reference with an unbounded (`'static`) lifetime.
///
/// # Safety
///
/// The caller must ensure the referent stays alive, is not moved and is not mutated for as
/// long as the returned reference is used.
unsafe fn extend_lifetime<T: ?Sized>(t: &T) -> &'static T {
    let raw: *const T = t;
    // SAFETY: `raw` was just derived from a valid shared reference; the caller upholds the
    // validity of the referent for the extended lifetime.
    unsafe { &*raw }
}
236
/// Stores a clone of `waker` in `waker_slot` unless the slot already holds a waker for
/// which [`Waker::will_wake`] reports that it wakes the same task.
fn replace_waker(waker_slot: &mut Option<Waker>, waker: &Waker) {
    match waker_slot.as_ref() {
        // Already registered: skip the clone.
        Some(current) if current.will_wake(waker) => {}
        _ => {
            waker_slot.replace(waker.clone());
        }
    }
}
242
243impl<S: Splittable + 'static> futures_util::AsyncRead for AsyncStream<S>
244where
245 S::ReadHalf: AsyncRead + Unpin,
246{
247 fn poll_read(
248 self: Pin<&mut Self>,
249 cx: &mut Context<'_>,
250 buf: &mut [u8],
251 ) -> Poll<io::Result<usize>> {
252 self.project().read_inner.poll_read(cx, buf)
253 }
254}
255
256impl<S: Splittable + 'static> AsyncStream<S>
257where
258 S::ReadHalf: AsyncRead + Unpin,
259{
260 pub fn poll_read_uninit(
264 self: Pin<&mut Self>,
265 cx: &mut Context<'_>,
266 buf: &mut [MaybeUninit<u8>],
267 ) -> Poll<io::Result<usize>> {
268 self.project().read_inner.poll_read_uninit(cx, buf)
269 }
270}
271
272impl<S: Splittable + 'static> futures_util::AsyncBufRead for AsyncStream<S>
273where
274 S::ReadHalf: AsyncRead + Unpin,
275{
276 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
277 self.project().read_inner.poll_fill_buf(cx)
278 }
279
280 fn consume(self: Pin<&mut Self>, amt: usize) {
281 self.project().read_inner.consume(amt)
282 }
283}
284
285impl<S: AsyncRead + Unpin + 'static> AsyncReadStream<S> {
286 fn poll_read_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
287 let this = self.project();
288 let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
292 let arr = WakerArrayRef::new([
293 this.read_waker.as_ref(),
294 this.read_uninit_waker.as_ref(),
295 this.read_buf_waker.as_ref(),
296 ]);
297 arr.with(|waker| {
298 let cx = &mut Context::from_waker(waker);
299 let res = poll_future!(this.read_future, cx, inner.fill_read_buf());
300 Poll::Ready(res)
301 })
302 }
303}
304
305impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncRead for AsyncReadStream<S> {
306 fn poll_read(
307 mut self: Pin<&mut Self>,
308 cx: &mut Context<'_>,
309 buf: &mut [u8],
310 ) -> Poll<io::Result<usize>> {
311 replace_waker(self.as_mut().project().read_waker, cx.waker());
312 loop {
313 let this = self.as_mut().project();
314 poll_future_would_block!(
315 cx,
316 this.read_waker,
317 io::Read::read(this.inner.get_mut(), buf),
318 self.as_mut().poll_read_impl()
319 )
320 }
321 }
322}
323
324impl<S: AsyncRead + Unpin + 'static> AsyncReadStream<S> {
325 pub fn poll_read_uninit(
329 mut self: Pin<&mut Self>,
330 cx: &mut Context<'_>,
331 buf: &mut [MaybeUninit<u8>],
332 ) -> Poll<io::Result<usize>> {
333 replace_waker(self.as_mut().project().read_uninit_waker, cx.waker());
334 loop {
335 let this = self.as_mut().project();
336 poll_future_would_block!(
337 cx,
338 this.read_uninit_waker,
339 this.inner.get_mut().read_buf_uninit(buf),
340 self.as_mut().poll_read_impl()
341 )
342 }
343 }
344}
345impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncBufRead for AsyncReadStream<S> {
346 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
347 replace_waker(self.as_mut().project().read_buf_waker, cx.waker());
348 loop {
349 let this = self.as_mut().project();
350 poll_future_would_block!(
351 cx,
352 this.read_buf_waker,
353 io::BufRead::fill_buf(this.inner.get_mut()).map(|s| unsafe { extend_lifetime(s) }),
356 self.as_mut().poll_read_impl()
357 )
358 }
359 }
360
361 fn consume(self: Pin<&mut Self>, amt: usize) {
362 io::BufRead::consume(self.project().inner.get_mut(), amt)
363 }
364}
365
366impl<S: Splittable + 'static> futures_util::AsyncWrite for AsyncStream<S>
367where
368 S::WriteHalf: AsyncWrite + Unpin,
369{
370 fn poll_write(
371 self: Pin<&mut Self>,
372 cx: &mut Context<'_>,
373 buf: &[u8],
374 ) -> Poll<io::Result<usize>> {
375 self.project().write_inner.poll_write(cx, buf)
376 }
377
378 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
379 self.project().write_inner.poll_flush(cx)
380 }
381
382 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
383 self.project().write_inner.poll_close(cx)
384 }
385}
386
387impl<S: AsyncWrite + Unpin + 'static> AsyncWriteStream<S> {
388 fn poll_flush_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
389 let this = self.project();
390 let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
394 let arr = WakerArrayRef::new([
395 this.write_waker.as_ref(),
396 this.flush_waker.as_ref(),
397 this.close_waker.as_ref(),
398 ]);
399 arr.with(|waker| {
400 let cx = &mut Context::from_waker(waker);
401 let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
402 Poll::Ready(res)
403 })
404 }
405
406 fn poll_close_impl(self: Pin<&mut Self>) -> Poll<io::Result<()>> {
407 if self.closed {
408 return Poll::Ready(Ok(()));
409 }
410 let this = self.project();
411 let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
415 let arr = WakerArrayRef::new([
416 this.write_waker.as_ref(),
417 this.flush_waker.as_ref(),
418 this.close_waker.as_ref(),
419 ]);
420 arr.with(|waker| {
421 let cx = &mut Context::from_waker(waker);
422 let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
423 Poll::Ready(res.inspect(|_| *this.closed = true))
424 })
425 }
426}
427
428impl<S: AsyncWrite + Unpin + 'static> futures_util::AsyncWrite for AsyncWriteStream<S> {
429 fn poll_write(
430 mut self: Pin<&mut Self>,
431 cx: &mut Context<'_>,
432 buf: &[u8],
433 ) -> Poll<io::Result<usize>> {
434 replace_waker(self.as_mut().project().write_waker, cx.waker());
435 if self.shutdown_future.is_some() {
436 debug_assert!(self.write_future.is_none());
437 ready!(self.as_mut().poll_close_impl())?;
438 }
439 loop {
440 let this = self.as_mut().project();
441 poll_future_would_block!(
442 cx,
443 this.write_waker,
444 io::Write::write(this.inner.get_mut(), buf),
445 self.as_mut().poll_flush_impl()
446 )
447 }
448 }
449
450 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
451 replace_waker(self.as_mut().project().flush_waker, cx.waker());
452 if self.shutdown_future.is_some() {
453 debug_assert!(self.write_future.is_none());
454 ready!(self.as_mut().poll_close_impl())?;
455 }
456 let res = ready!(self.as_mut().poll_flush_impl());
457 self.project().flush_waker.take();
458 Poll::Ready(res.map(|_| ()))
459 }
460
461 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
462 replace_waker(self.as_mut().project().close_waker, cx.waker());
463 if self.write_future.is_some() || self.inner.has_pending_write() {
466 debug_assert!(self.shutdown_future.is_none());
467 ready!(self.as_mut().poll_flush_impl())?;
468 }
469 let res = ready!(self.as_mut().poll_close_impl());
470 self.project().close_waker.take();
471 Poll::Ready(res)
472 }
473}
474
475impl<S: Splittable> Debug for AsyncStream<S> {
476 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477 f.debug_struct("AsyncStream").finish_non_exhaustive()
478 }
479}
480
481impl<S: Splittable> Splittable for AsyncStream<S> {
482 type ReadHalf = AsyncReadStream<S::ReadHalf>;
483 type WriteHalf = AsyncWriteStream<S::WriteHalf>;
484
485 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
486 (self.read_inner, self.write_inner)
487 }
488}
489
#[cfg(test)]
mod test {
    use futures_executor::block_on;
    use futures_util::AsyncWriteExt;

    use super::AsyncWriteStream;

    /// Writing five bytes and closing the stream must leave exactly those bytes in the
    /// underlying `Vec`.
    #[test]
    fn close() {
        let scenario = async {
            let sink = AsyncWriteStream::new(Vec::<u8>::new());
            let mut sink = std::pin::pin!(sink);
            let written = sink.write(b"hello").await.unwrap();
            assert_eq!(written, 5);
            sink.close().await.unwrap();
            assert_eq!(sink.get_ref(), b"hello");
        };
        block_on(scenario)
    }
}