1use std::{
2 fmt::Debug,
3 io::{self, BufRead},
4 marker::PhantomPinned,
5 mem::MaybeUninit,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll, Wake, Waker, ready},
9};
10
11use pin_project_lite::pin_project;
12
13use crate::{AsyncRead, AsyncWrite, PinBoxFuture, compat::SyncStream, util::DEFAULT_BUF_SIZE};
14
15pin_project! {
16 pub struct AsyncStream<S> {
18 #[pin]
19 inner: SyncStream<S>,
20 read_future: Option<PinBoxFuture<io::Result<usize>>>,
21 write_future: Option<PinBoxFuture<io::Result<usize>>>,
22 shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
23 read_waker: Option<Waker>,
24 read_uninit_waker: Option<Waker>,
25 read_buf_waker: Option<Waker>,
26 write_waker: Option<Waker>,
27 flush_waker: Option<Waker>,
28 close_waker: Option<Waker>,
29 #[pin]
30 _p: PhantomPinned,
31 }
32}
33
34impl<S> AsyncStream<S> {
35 pub fn new(stream: S) -> Self {
37 Self::new_impl(SyncStream::new(stream))
38 }
39
40 pub fn with_capacity(cap: usize, stream: S) -> Self {
42 Self::new_impl(SyncStream::with_capacity(cap, stream))
43 }
44
45 fn new_impl(inner: SyncStream<S>) -> Self {
46 Self {
47 inner,
48 read_future: None,
49 write_future: None,
50 shutdown_future: None,
51 read_waker: None,
52 read_uninit_waker: None,
53 read_buf_waker: None,
54 write_waker: None,
55 flush_waker: None,
56 close_waker: None,
57 _p: PhantomPinned,
58 }
59 }
60
61 pub fn get_ref(&self) -> &S {
63 self.inner.get_ref()
64 }
65
66 pub fn get_mut(&mut self) -> &mut S {
68 self.inner.get_mut()
69 }
70
71 pub fn into_inner(self) -> S {
73 self.inner.into_inner()
74 }
75}
76
77pin_project! {
78 pub struct AsyncReadStream<S> {
83 #[pin]
84 inner: SyncStream<S>,
85 read_future: Option<PinBoxFuture<io::Result<usize>>>,
86 read_waker: Option<Waker>,
87 read_uninit_waker: Option<Waker>,
88 read_buf_waker: Option<Waker>,
89 #[pin]
90 _p: PhantomPinned,
91 }
92}
93
94impl<S> AsyncReadStream<S> {
95 pub fn new(stream: S) -> Self {
97 Self::with_capacity(DEFAULT_BUF_SIZE, stream)
98 }
99
100 pub fn with_capacity(cap: usize, stream: S) -> Self {
102 Self::new_impl(SyncStream::with_limits2(
103 cap,
104 0,
105 cap,
106 SyncStream::<S>::DEFAULT_MAX_BUFFER,
107 stream,
108 ))
109 }
110
111 fn new_impl(inner: SyncStream<S>) -> Self {
112 Self {
113 inner,
114 read_future: None,
115 read_waker: None,
116 read_uninit_waker: None,
117 read_buf_waker: None,
118 _p: PhantomPinned,
119 }
120 }
121
122 pub fn get_ref(&self) -> &S {
124 self.inner.get_ref()
125 }
126
127 pub fn get_mut(&mut self) -> &mut S {
129 self.inner.get_mut()
130 }
131
132 pub fn into_inner(self) -> S {
134 self.inner.into_inner()
135 }
136}
137
138pin_project! {
139 pub struct AsyncWriteStream<S> {
143 #[pin]
144 inner: SyncStream<S>,
145 write_future: Option<PinBoxFuture<io::Result<usize>>>,
146 shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
147 write_waker: Option<Waker>,
148 flush_waker: Option<Waker>,
149 close_waker: Option<Waker>,
150 #[pin]
151 _p: PhantomPinned,
152 }
153}
154
155impl<S> AsyncWriteStream<S> {
156 pub fn new(stream: S) -> Self {
158 Self::with_capacity(DEFAULT_BUF_SIZE, stream)
159 }
160
161 pub fn with_capacity(cap: usize, stream: S) -> Self {
163 Self::new_impl(SyncStream::with_limits2(
164 0,
165 cap,
166 cap,
167 SyncStream::<S>::DEFAULT_MAX_BUFFER,
168 stream,
169 ))
170 }
171
172 fn new_impl(inner: SyncStream<S>) -> Self {
173 Self {
174 inner,
175 write_future: None,
176 shutdown_future: None,
177 write_waker: None,
178 flush_waker: None,
179 close_waker: None,
180 _p: PhantomPinned,
181 }
182 }
183
184 pub fn get_ref(&self) -> &S {
186 self.inner.get_ref()
187 }
188
189 pub fn get_mut(&mut self) -> &mut S {
191 self.inner.get_mut()
192 }
193
194 pub fn into_inner(self) -> S {
196 self.inner.into_inner()
197 }
198}
199
// Polls the boxed future cached in the `Option` slot `$f` with context `$cx`.
//
// When the slot is empty a new future is created from `$e` and boxed. If the
// future is still pending it is put back into the slot and the *enclosing
// function* returns `Poll::Pending`; otherwise the macro evaluates to the
// future's output and the slot is left empty.
macro_rules! poll_future {
    ($f:expr, $cx:expr, $e:expr) => {{
        let mut pending = match $f.take() {
            Some(existing) => existing,
            None => Box::pin($e),
        };
        if let Poll::Ready(output) = pending.as_mut().poll($cx) {
            output
        } else {
            // Not finished: keep the future so the next poll resumes it.
            $f.replace(pending);
            return Poll::Pending;
        }
    }};
}
216
// Evaluates the synchronous I/O attempt `$io` once.
//
// * `WouldBlock` error: polls `$f`; the enclosing function returns early if
//   `$f` is pending or failed, otherwise the macro finishes with `()` so the
//   caller's loop can retry the attempt.
// * Any other outcome (success or error): clears the waker slot `$w` and
//   returns it from the enclosing function wrapped in `Poll::Ready`.
//
// `$cx` is accepted for call-site symmetry but is not used by the expansion.
macro_rules! poll_future_would_block {
    ($cx:expr, $w:expr, $io:expr, $f:expr) => {{
        match $io {
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                ready!($f)?;
            }
            finished => {
                $w.take();
                return Poll::Ready(finished);
            }
        }
    }};
}
234
/// Reinterprets an exclusive borrow as a `'static` exclusive borrow.
///
/// # Safety
///
/// The caller must guarantee that the referent stays alive, is not moved, and
/// is not accessed through any conflicting reference for as long as the
/// returned reference is used.
unsafe fn extend_lifetime_mut<T: ?Sized>(t: &mut T) -> &'static mut T {
    let raw: *mut T = t;
    // SAFETY: `raw` comes from a valid exclusive reference; the caller upholds
    // the lifetime contract documented above.
    unsafe { &mut *raw }
}
238
/// Reinterprets a shared borrow as a `'static` shared borrow.
///
/// # Safety
///
/// The caller must guarantee that the referent stays alive, is not moved, and
/// is not mutated for as long as the returned reference is used.
unsafe fn extend_lifetime<T: ?Sized>(t: &T) -> &'static T {
    let raw: *const T = t;
    // SAFETY: `raw` comes from a valid shared reference; the caller upholds
    // the lifetime contract documented above.
    unsafe { &*raw }
}
242
/// Stores a clone of `waker` in `waker_slot` unless the slot already holds a
/// waker that would wake the same task.
fn replace_waker(waker_slot: &mut Option<Waker>, waker: &Waker) {
    let already_registered =
        matches!(waker_slot.as_ref(), Some(current) if current.will_wake(waker));
    if already_registered {
        // Skip the clone: the stored waker is equivalent.
        return;
    }
    *waker_slot = Some(waker.clone());
}
248
249impl<S: AsyncRead + Unpin + 'static> AsyncStream<S>
250where
251 for<'a> &'a S: AsyncRead,
252{
253 fn poll_read_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
254 let this = self.project();
255 let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
261 let arr = WakerArray([
262 this.read_waker.as_ref().cloned(),
263 this.read_uninit_waker.as_ref().cloned(),
264 this.read_buf_waker.as_ref().cloned(),
265 ]);
266 let waker = Waker::from(Arc::new(arr));
267 let cx = &mut Context::from_waker(&waker);
268 let res = poll_future!(this.read_future, cx, inner.fill_read_buf());
269 Poll::Ready(res)
270 }
271}
272
273impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncRead for AsyncStream<S>
274where
275 for<'a> &'a S: AsyncRead,
276{
277 fn poll_read(
278 mut self: Pin<&mut Self>,
279 cx: &mut Context<'_>,
280 buf: &mut [u8],
281 ) -> Poll<io::Result<usize>> {
282 replace_waker(self.as_mut().project().read_waker, cx.waker());
283 loop {
284 let this = self.as_mut().project();
285 poll_future_would_block!(
286 cx,
287 this.read_waker,
288 io::Read::read(this.inner.get_mut(), buf),
289 self.as_mut().poll_read_impl()
290 )
291 }
292 }
293}
294
295impl<S: AsyncRead + Unpin + 'static> AsyncStream<S>
296where
297 for<'a> &'a S: AsyncRead,
298{
299 pub fn poll_read_uninit(
303 mut self: Pin<&mut Self>,
304 cx: &mut Context<'_>,
305 buf: &mut [MaybeUninit<u8>],
306 ) -> Poll<io::Result<usize>> {
307 replace_waker(self.as_mut().project().read_uninit_waker, cx.waker());
308 loop {
309 let this = self.as_mut().project();
310 poll_future_would_block!(
311 cx,
312 this.read_uninit_waker,
313 this.inner.get_mut().read_buf_uninit(buf),
314 self.as_mut().poll_read_impl()
315 )
316 }
317 }
318}
319
320impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncBufRead for AsyncStream<S>
321where
322 for<'a> &'a S: AsyncRead,
323{
324 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
325 replace_waker(self.as_mut().project().read_buf_waker, cx.waker());
326 loop {
327 let this = self.as_mut().project();
328 poll_future_would_block!(
329 cx,
330 this.read_buf_waker,
331 io::BufRead::fill_buf(this.inner.get_mut()).map(|s| unsafe { extend_lifetime(s) }),
334 self.as_mut().poll_read_impl()
335 )
336 }
337 }
338
339 fn consume(self: Pin<&mut Self>, amt: usize) {
340 self.project().inner.consume(amt)
341 }
342}
343
344impl<S: AsyncRead + Unpin + 'static> AsyncReadStream<S> {
345 fn poll_read_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
346 let this = self.project();
347 let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
351 let arr = WakerArray([
352 this.read_waker.as_ref().cloned(),
353 this.read_uninit_waker.as_ref().cloned(),
354 this.read_buf_waker.as_ref().cloned(),
355 ]);
356 let waker = Waker::from(Arc::new(arr));
357 let cx = &mut Context::from_waker(&waker);
358 let res = poll_future!(this.read_future, cx, inner.fill_read_buf());
359 Poll::Ready(res)
360 }
361}
362
363impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncRead for AsyncReadStream<S> {
364 fn poll_read(
365 mut self: Pin<&mut Self>,
366 cx: &mut Context<'_>,
367 buf: &mut [u8],
368 ) -> Poll<io::Result<usize>> {
369 replace_waker(self.as_mut().project().read_waker, cx.waker());
370 loop {
371 let this = self.as_mut().project();
372 poll_future_would_block!(
373 cx,
374 this.read_waker,
375 io::Read::read(this.inner.get_mut(), buf),
376 self.as_mut().poll_read_impl()
377 )
378 }
379 }
380}
381
382impl<S: AsyncRead + Unpin + 'static> AsyncReadStream<S> {
383 pub fn poll_read_uninit(
387 mut self: Pin<&mut Self>,
388 cx: &mut Context<'_>,
389 buf: &mut [MaybeUninit<u8>],
390 ) -> Poll<io::Result<usize>> {
391 replace_waker(self.as_mut().project().read_uninit_waker, cx.waker());
392 loop {
393 let this = self.as_mut().project();
394 poll_future_would_block!(
395 cx,
396 this.read_uninit_waker,
397 this.inner.get_mut().read_buf_uninit(buf),
398 self.as_mut().poll_read_impl()
399 )
400 }
401 }
402}
403impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncBufRead for AsyncReadStream<S> {
404 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
405 replace_waker(self.as_mut().project().read_buf_waker, cx.waker());
406 loop {
407 let this = self.as_mut().project();
408 poll_future_would_block!(
409 cx,
410 this.read_buf_waker,
411 io::BufRead::fill_buf(this.inner.get_mut()).map(|s| unsafe { extend_lifetime(s) }),
414 self.as_mut().poll_read_impl()
415 )
416 }
417 }
418
419 fn consume(self: Pin<&mut Self>, amt: usize) {
420 self.project().inner.consume(amt)
421 }
422}
423
424impl<S: AsyncWrite + Unpin + 'static> AsyncStream<S>
425where
426 for<'a> &'a S: AsyncWrite,
427{
428 fn poll_flush_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
429 let this = self.project();
430 let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
436 let arr = WakerArray([
437 this.write_waker.as_ref().cloned(),
438 this.flush_waker.as_ref().cloned(),
439 this.close_waker.as_ref().cloned(),
440 ]);
441 let waker = Waker::from(Arc::new(arr));
442 let cx = &mut Context::from_waker(&waker);
443 let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
444 Poll::Ready(res)
445 }
446
447 fn poll_close_impl(self: Pin<&mut Self>) -> Poll<io::Result<()>> {
448 let this = self.project();
449 let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
455 let arr = WakerArray([
456 this.write_waker.as_ref().cloned(),
457 this.flush_waker.as_ref().cloned(),
458 this.close_waker.as_ref().cloned(),
459 ]);
460 let waker = Waker::from(Arc::new(arr));
461 let cx = &mut Context::from_waker(&waker);
462 let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
463 Poll::Ready(res)
464 }
465}
466
467impl<S: AsyncWrite + Unpin + 'static> futures_util::AsyncWrite for AsyncStream<S>
468where
469 for<'a> &'a S: AsyncWrite,
470{
471 fn poll_write(
472 mut self: Pin<&mut Self>,
473 cx: &mut Context<'_>,
474 buf: &[u8],
475 ) -> Poll<io::Result<usize>> {
476 replace_waker(self.as_mut().project().write_waker, cx.waker());
477 if self.shutdown_future.is_some() {
478 debug_assert!(self.write_future.is_none());
479 ready!(self.as_mut().poll_close_impl())?;
480 }
481 loop {
482 let this = self.as_mut().project();
483 poll_future_would_block!(
484 cx,
485 this.write_waker,
486 io::Write::write(this.inner.get_mut(), buf),
487 self.as_mut().poll_flush_impl()
488 )
489 }
490 }
491
492 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
493 replace_waker(self.as_mut().project().flush_waker, cx.waker());
494 if self.shutdown_future.is_some() {
495 debug_assert!(self.write_future.is_none());
496 ready!(self.as_mut().poll_close_impl())?;
497 }
498 let res = ready!(self.as_mut().poll_flush_impl());
499 self.project().flush_waker.take();
500 Poll::Ready(res.map(|_| ()))
501 }
502
503 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
504 replace_waker(self.as_mut().project().close_waker, cx.waker());
505 if self.write_future.is_some() || self.inner.has_pending_write() {
508 debug_assert!(self.shutdown_future.is_none());
509 ready!(self.as_mut().poll_flush_impl())?;
510 }
511 let res = ready!(self.as_mut().poll_close_impl());
512 self.project().close_waker.take();
513 Poll::Ready(res)
514 }
515}
516
517impl<S: AsyncWrite + Unpin + 'static> AsyncWriteStream<S> {
518 fn poll_flush_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
519 let this = self.project();
520 let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
524 let arr = WakerArray([
525 this.write_waker.as_ref().cloned(),
526 this.flush_waker.as_ref().cloned(),
527 this.close_waker.as_ref().cloned(),
528 ]);
529 let waker = Waker::from(Arc::new(arr));
530 let cx = &mut Context::from_waker(&waker);
531 let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
532 Poll::Ready(res)
533 }
534
535 fn poll_close_impl(self: Pin<&mut Self>) -> Poll<io::Result<()>> {
536 let this = self.project();
537 let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
541 let arr = WakerArray([
542 this.write_waker.as_ref().cloned(),
543 this.flush_waker.as_ref().cloned(),
544 this.close_waker.as_ref().cloned(),
545 ]);
546 let waker = Waker::from(Arc::new(arr));
547 let cx = &mut Context::from_waker(&waker);
548 let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
549 Poll::Ready(res)
550 }
551}
552
553impl<S: AsyncWrite + Unpin + 'static> futures_util::AsyncWrite for AsyncWriteStream<S> {
554 fn poll_write(
555 mut self: Pin<&mut Self>,
556 cx: &mut Context<'_>,
557 buf: &[u8],
558 ) -> Poll<io::Result<usize>> {
559 replace_waker(self.as_mut().project().write_waker, cx.waker());
560 if self.shutdown_future.is_some() {
561 debug_assert!(self.write_future.is_none());
562 ready!(self.as_mut().poll_close_impl())?;
563 }
564 loop {
565 let this = self.as_mut().project();
566 poll_future_would_block!(
567 cx,
568 this.write_waker,
569 io::Write::write(this.inner.get_mut(), buf),
570 self.as_mut().poll_flush_impl()
571 )
572 }
573 }
574
575 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
576 replace_waker(self.as_mut().project().flush_waker, cx.waker());
577 if self.shutdown_future.is_some() {
578 debug_assert!(self.write_future.is_none());
579 ready!(self.as_mut().poll_close_impl())?;
580 }
581 let res = ready!(self.as_mut().poll_flush_impl());
582 self.project().flush_waker.take();
583 Poll::Ready(res.map(|_| ()))
584 }
585
586 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
587 replace_waker(self.as_mut().project().close_waker, cx.waker());
588 if self.write_future.is_some() || self.inner.has_pending_write() {
591 debug_assert!(self.shutdown_future.is_none());
592 ready!(self.as_mut().poll_flush_impl())?;
593 }
594 let res = ready!(self.as_mut().poll_close_impl());
595 self.project().close_waker.take();
596 Poll::Ready(res)
597 }
598}
599
600impl<S: Debug> Debug for AsyncStream<S> {
601 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
602 f.debug_struct("AsyncStream")
603 .field("inner", &self.inner)
604 .finish_non_exhaustive()
605 }
606}
607
/// A fixed-size set of optional wakers that are all notified together.
///
/// Wrapped in an `Arc` and converted into a [`Waker`], it lets one pending
/// operation wake every task that registered interest in it.
struct WakerArray<const N: usize>([Option<Waker>; N]);

impl<const N: usize> Wake for WakerArray<N> {
    /// Wakes every stored waker; empty slots are skipped.
    fn wake(self: Arc<Self>) {
        Wake::wake_by_ref(&self);
    }

    /// Wakes every stored waker without consuming the `Arc`.
    ///
    /// Implemented explicitly because the trait's default `wake_by_ref`
    /// clones the `Arc` and calls `wake`, which costs an extra atomic
    /// increment/decrement on every notification.
    fn wake_by_ref(self: &Arc<Self>) {
        for waker in self.0.iter().flatten() {
            waker.wake_by_ref();
        }
    }
}
619
#[cfg(test)]
mod test {
    use futures_executor::block_on;
    use futures_util::AsyncWriteExt;

    use super::AsyncWriteStream;

    /// Writing and then closing must leave every written byte in the sink.
    #[test]
    fn close() {
        block_on(async {
            // The adapter is `!Unpin`, so it has to be pinned before use.
            let mut stream = std::pin::pin!(AsyncWriteStream::new(Vec::<u8>::new()));
            let written = stream.write(b"hello").await.unwrap();
            assert_eq!(written, 5);
            stream.close().await.unwrap();
            assert_eq!(stream.get_ref(), b"hello");
        })
    }
}