1use std::{
2 io,
3 mem::MaybeUninit,
4 task::{Context, Poll},
5};
6
7use compio_buf::{BufResult, IoBufMut, bytes::Bytes};
8use compio_io::AsyncRead;
9use futures_util::future::poll_fn;
10use quinn_proto::{Chunk, Chunks, ClosedStream, ReadableError, StreamId, VarInt};
11use thiserror::Error;
12
13use crate::{ConnectionError, ConnectionInner, sync::shared::Shared};
14
/// A stream that can only be used to receive data.
///
/// Data is read with [`read_chunk`](RecvStream::read_chunk),
/// [`read_chunks`](RecvStream::read_chunks), [`read_to_end`](RecvStream::read_to_end)
/// or the `AsyncRead` implementation.
///
/// When dropped before all data has been read, the stream calls `stop` with error code 0 on
/// the underlying QUIC stream. This does not happen if the connection already failed or
/// 0-RTT was rejected.
#[derive(Debug)]
pub struct RecvStream {
    /// Shared connection state; locked via `state()` for every operation.
    conn: Shared<ConnectionInner>,
    /// Identifier of this stream within the connection.
    stream: StreamId,
    /// Whether the stream was opened with 0-RTT data. Such streams report
    /// `ZeroRttRejected` once the connection's `check_0rtt()` fails.
    is_0rtt: bool,
    /// Set once the stream has finished, been reset, or been stopped locally.
    /// Further reads then return `Ok(None)`, and `Drop` skips calling `stop`.
    all_data_read: bool,
    /// Reset code observed during a read. It is replayed by later reads and
    /// returned by `received_reset`.
    reset: Option<VarInt>,
}
59
60impl RecvStream {
61 pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
62 Self {
63 conn,
64 stream,
65 is_0rtt,
66 all_data_read: false,
67 reset: None,
68 }
69 }
70
71 pub fn id(&self) -> StreamId {
73 self.stream
74 }
75
76 pub fn is_0rtt(&self) -> bool {
82 self.is_0rtt
83 }
84
85 pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
91 let mut state = self.conn.state();
92 if self.is_0rtt && !state.check_0rtt() {
93 return Ok(());
94 }
95 state.conn.recv_stream(self.stream).stop(error_code)?;
96 state.wake();
97 self.all_data_read = true;
98 Ok(())
99 }
100
101 pub async fn received_reset(&mut self) -> Result<Option<VarInt>, ResetError> {
112 poll_fn(|cx| {
113 let mut state = self.conn.state();
114
115 if self.is_0rtt && !state.check_0rtt() {
116 return Poll::Ready(Err(ResetError::ZeroRttRejected));
117 }
118 if let Some(code) = self.reset {
119 return Poll::Ready(Ok(Some(code)));
120 }
121
122 match state.conn.recv_stream(self.stream).received_reset() {
123 Err(_) => Poll::Ready(Ok(None)),
124 Ok(Some(error_code)) => {
125 state.wake();
128 Poll::Ready(Ok(Some(error_code)))
129 }
130 Ok(None) => {
131 if let Some(e) = &state.error {
132 return Poll::Ready(Err(e.clone().into()));
133 }
134 state.readable.insert(self.stream, cx.waker().clone());
139 Poll::Pending
140 }
141 }
142 })
143 .await
144 }
145
146 fn execute_poll_read<F, T>(
154 &mut self,
155 cx: &mut Context,
156 ordered: bool,
157 mut read_fn: F,
158 ) -> Poll<Result<Option<T>, ReadError>>
159 where
160 F: FnMut(&mut Chunks) -> ReadStatus<T>,
161 {
162 use quinn_proto::ReadError::*;
163
164 if self.all_data_read {
165 return Poll::Ready(Ok(None));
166 }
167
168 let mut state = self.conn.state();
169 if self.is_0rtt && !state.check_0rtt() {
170 return Poll::Ready(Err(ReadError::ZeroRttRejected));
171 }
172
173 let status = match self.reset {
177 Some(code) => ReadStatus::Failed(None, Reset(code)),
178 None => {
179 let mut recv = state.conn.recv_stream(self.stream);
180 let mut chunks = recv.read(ordered)?;
181 let status = read_fn(&mut chunks);
182 if chunks.finalize().should_transmit() {
183 state.wake();
184 }
185 status
186 }
187 };
188
189 match status {
190 ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))),
191 ReadStatus::Finished(read) => {
192 self.all_data_read = true;
193 Poll::Ready(Ok(read))
194 }
195 ReadStatus::Failed(read, Blocked) => match read {
196 Some(val) => Poll::Ready(Ok(Some(val))),
197 None => {
198 if let Some(error) = &state.error {
199 return Poll::Ready(Err(error.clone().into()));
200 }
201 state.readable.insert(self.stream, cx.waker().clone());
202 Poll::Pending
203 }
204 },
205 ReadStatus::Failed(read, Reset(error_code)) => match read {
206 None => {
207 self.all_data_read = true;
208 self.reset = Some(error_code);
209 Poll::Ready(Err(ReadError::Reset(error_code)))
210 }
211 done => {
212 self.reset = Some(error_code);
213 Poll::Ready(Ok(done))
214 }
215 },
216 }
217 }
218
219 pub(crate) fn poll_read_impl(
220 &mut self,
221 cx: &mut Context,
222 buf: &mut [MaybeUninit<u8>],
223 ) -> Poll<Result<Option<usize>, ReadError>> {
224 if buf.is_empty() {
225 return Poll::Ready(Ok(Some(0)));
226 }
227
228 self.execute_poll_read(cx, true, |chunks| {
229 let mut read = 0;
230 loop {
231 if read >= buf.len() {
232 return ReadStatus::Readable(read);
234 }
235
236 match chunks.next(buf.len() - read) {
237 Ok(Some(chunk)) => {
238 let bytes = chunk.bytes;
239 let len = bytes.len();
240 buf[read..read + len].copy_from_slice(unsafe {
241 std::slice::from_raw_parts(bytes.as_ptr().cast(), len)
242 });
243 read += len;
244 }
245 res => {
246 return (if read == 0 { None } else { Some(read) }, res.err()).into();
247 }
248 }
249 }
250 })
251 }
252
253 pub fn poll_read_uninit(
266 &mut self,
267 cx: &mut Context,
268 buf: &mut [MaybeUninit<u8>],
269 ) -> Poll<Result<usize, ReadError>> {
270 self.poll_read_impl(cx, buf)
271 .map(|res| res.map(|n| n.unwrap_or_default()))
272 }
273
274 pub async fn read_chunk(
291 &mut self,
292 max_length: usize,
293 ordered: bool,
294 ) -> Result<Option<Chunk>, ReadError> {
295 poll_fn(|cx| {
296 self.execute_poll_read(cx, ordered, |chunks| match chunks.next(max_length) {
297 Ok(Some(chunk)) => ReadStatus::Readable(chunk),
298 res => (None, res.err()).into(),
299 })
300 })
301 .await
302 }
303
304 pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Option<usize>, ReadError> {
315 if bufs.is_empty() {
316 return Ok(Some(0));
317 }
318
319 poll_fn(|cx| {
320 self.execute_poll_read(cx, true, |chunks| {
321 let mut read = 0;
322 loop {
323 if read >= bufs.len() {
324 return ReadStatus::Readable(read);
326 }
327
328 match chunks.next(usize::MAX) {
329 Ok(Some(chunk)) => {
330 bufs[read] = chunk.bytes;
331 read += 1;
332 }
333 res => {
334 return (if read == 0 { None } else { Some(read) }, res.err()).into();
335 }
336 }
337 }
338 })
339 })
340 .await
341 }
342
343 pub async fn read_to_end<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
350 let mut start = u64::MAX;
351 let mut end = 0;
352 let mut chunks = vec![];
353 loop {
354 let chunk = match self.read_chunk(usize::MAX, false).await {
355 Ok(Some(chunk)) => chunk,
356 Ok(None) => break,
357 Err(e) => return BufResult(Err(e.into()), buf),
358 };
359 start = start.min(chunk.offset);
360 end = end.max(chunk.offset + chunk.bytes.len() as u64);
361 chunks.push((chunk.offset, chunk.bytes));
362 }
363 if start == u64::MAX || start >= end {
364 return BufResult(Ok(0), buf);
366 }
367 let len = (end - start) as usize;
368 let cap = buf.buf_capacity();
369 let needed = len.saturating_sub(cap);
370 if needed > 0
371 && let Err(e) = buf.reserve(needed)
372 {
373 return BufResult(Err(io::Error::new(io::ErrorKind::OutOfMemory, e)), buf);
374 }
375 let slice = &mut buf.as_uninit()[..len];
376 slice.fill(MaybeUninit::new(0));
377 for (offset, bytes) in chunks {
378 let offset = (offset - start) as usize;
379 let buf_len = bytes.len();
380 slice[offset..offset + buf_len].copy_from_slice(unsafe {
381 std::slice::from_raw_parts(bytes.as_ptr().cast(), buf_len)
382 });
383 }
384 unsafe { buf.advance_to(len) }
385 BufResult(Ok(len), buf)
386 }
387
388 #[cfg(feature = "io-compat")]
390 pub fn into_compat(self) -> CompatRecvStream {
391 CompatRecvStream(self)
392 }
393}
394
395impl Drop for RecvStream {
396 fn drop(&mut self) {
397 let mut state = self.conn.state();
398
399 state.readable.remove(&self.stream);
401
402 if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
403 return;
404 }
405 if !self.all_data_read {
406 let _ = state.conn.recv_stream(self.stream).stop(0u32.into());
408 state.wake();
409 }
410 }
411}
412
413enum ReadStatus<T> {
414 Readable(T),
415 Finished(Option<T>),
416 Failed(Option<T>, quinn_proto::ReadError),
417}
418
419impl<T> From<(Option<T>, Option<quinn_proto::ReadError>)> for ReadStatus<T> {
420 fn from(status: (Option<T>, Option<quinn_proto::ReadError>)) -> Self {
421 match status {
422 (read, None) => Self::Finished(read),
423 (read, Some(e)) => Self::Failed(read, e),
424 }
425 }
426}
427
428#[derive(Debug, Error, Clone, PartialEq, Eq)]
430pub enum ReadError {
431 #[error("stream reset by peer: error {0}")]
435 Reset(VarInt),
436 #[error("connection lost")]
438 ConnectionLost(#[from] ConnectionError),
439 #[error("closed stream")]
441 ClosedStream,
442 #[error("ordered read after unordered read")]
448 IllegalOrderedRead,
449 #[error("0-RTT rejected")]
456 ZeroRttRejected,
457}
458
459impl From<ReadableError> for ReadError {
460 fn from(e: ReadableError) -> Self {
461 match e {
462 ReadableError::ClosedStream => Self::ClosedStream,
463 ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead,
464 }
465 }
466}
467
468impl From<ResetError> for ReadError {
469 fn from(e: ResetError) -> Self {
470 match e {
471 ResetError::ConnectionLost(e) => Self::ConnectionLost(e),
472 ResetError::ZeroRttRejected => Self::ZeroRttRejected,
473 }
474 }
475}
476
477impl From<ReadError> for io::Error {
478 fn from(x: ReadError) -> Self {
479 use self::ReadError::*;
480 let kind = match x {
481 Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
482 ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
483 IllegalOrderedRead => io::ErrorKind::InvalidInput,
484 };
485 Self::new(kind, x)
486 }
487}
488
/// Errors from reading exactly enough data to fill a buffer.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadExactError {
    /// The stream finished before the buffer was filled. The value is the
    /// buffer's remaining capacity when the end of the stream was reached.
    #[error("stream finished early (expected {0} bytes more)")]
    FinishedEarly(usize),
    /// A read error occurred.
    #[error(transparent)]
    ReadError(#[from] ReadError),
}
499
500#[derive(Debug, Error, Clone, PartialEq, Eq)]
502pub enum ResetError {
503 #[error("connection lost")]
505 ConnectionLost(#[from] ConnectionError),
506 #[error("0-RTT rejected")]
513 ZeroRttRejected,
514}
515
516impl From<ResetError> for io::Error {
517 fn from(x: ResetError) -> Self {
518 use ResetError::*;
519 let kind = match x {
520 ZeroRttRejected => io::ErrorKind::ConnectionReset,
521 ConnectionLost(_) => io::ErrorKind::NotConnected,
522 };
523 Self::new(kind, x)
524 }
525}
526
527impl AsyncRead for RecvStream {
528 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
529 let res = poll_fn(|cx| self.poll_read_uninit(cx, buf.as_uninit()))
530 .await
531 .inspect(|&n| unsafe { buf.advance_to(n) })
532 .map_err(Into::into);
533 BufResult(res, buf)
534 }
535}
536
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
        task::ready,
    };

    use compio_buf::{IntoInner, bytes::BufMut};

    use super::*;

    /// Wrapper around [`RecvStream`] that reads into `bytes::BufMut` buffers
    /// and implements [`futures_util::AsyncRead`].
    pub struct CompatRecvStream(pub(super) RecvStream);

    impl CompatRecvStream {
        /// Polls a read into the spare capacity of `buf` and advances `buf` by
        /// the number of bytes read. Returns `Ok(None)` at end of stream.
        fn poll_read(
            &mut self,
            cx: &mut Context,
            mut buf: impl BufMut,
        ) -> Poll<Result<Option<usize>, ReadError>> {
            // SAFETY: `poll_read_impl` only ever writes initialized bytes into
            // the slice.
            let spare = unsafe { buf.chunk_mut().as_uninit_slice_mut() };
            let res = ready!(self.poll_read_impl(cx, spare));
            if let Ok(Some(n)) = res {
                // SAFETY: the first `n` bytes of the chunk were just initialized.
                unsafe { buf.advance_mut(n) }
            }
            Poll::Ready(res)
        }

        /// Reads data into `buf`.
        ///
        /// Returns the number of bytes read, or `None` once the stream has finished.
        pub async fn read(&mut self, mut buf: impl BufMut) -> Result<Option<usize>, ReadError> {
            poll_fn(|cx| self.poll_read(cx, &mut buf)).await
        }

        /// Reads until `buf` has no remaining capacity.
        ///
        /// # Errors
        ///
        /// - [`ReadExactError::FinishedEarly`] if the stream ends first.
        /// - [`ReadExactError::ReadError`] for any read failure.
        pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> {
            poll_fn(|cx| loop {
                if !buf.has_remaining_mut() {
                    return Poll::Ready(Ok(()));
                }
                match ready!(self.poll_read(cx, &mut buf)) {
                    Ok(Some(_)) => {}
                    Ok(None) => {
                        return Poll::Ready(Err(ReadExactError::FinishedEarly(
                            buf.remaining_mut(),
                        )));
                    }
                    Err(e) => return Poll::Ready(Err(e.into())),
                }
            })
            .await
        }
    }

    impl IntoInner for CompatRecvStream {
        type Inner = RecvStream;

        /// Unwraps the underlying [`RecvStream`].
        fn into_inner(self) -> Self::Inner {
            let Self(stream) = self;
            stream
        }
    }

    impl Deref for CompatRecvStream {
        type Target = RecvStream;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl DerefMut for CompatRecvStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.0
        }
    }

    impl futures_util::AsyncRead for CompatRecvStream {
        /// Reads into an initialized byte slice. End of stream is `Ok(0)`.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let len = buf.len();
            // SAFETY: initialized bytes are valid `MaybeUninit<u8>`, and the
            // reader only writes initialized bytes back into the slice.
            let uninit = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), len) };
            let this = self.get_mut();
            this.poll_read_uninit(cx, uninit).map_err(Into::into)
        }
    }
}
634
635#[cfg(feature = "io-compat")]
636pub use compat::CompatRecvStream;
637
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use h3::quic::{self, StreamErrorIncoming};

    use super::*;

    impl From<ReadError> for StreamErrorIncoming {
        /// Converts a [`ReadError`] into the error type reported to `h3`.
        ///
        /// Stream resets and connection losses map onto their dedicated `h3`
        /// variants. Every other error is boxed into
        /// [`StreamErrorIncoming::Unknown`].
        fn from(e: ReadError) -> Self {
            match e {
                ReadError::Reset(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ReadError::ConnectionLost(e) => Self::ConnectionErrorIncoming {
                    connection_error: e.into(),
                },
                // `poll_data` only performs ordered reads. However, this
                // conversion is public, and the stream may have been read
                // unordered before being handed to `h3`. Report
                // `IllegalOrderedRead` instead of panicking.
                e @ (ReadError::ClosedStream
                | ReadError::IllegalOrderedRead
                | ReadError::ZeroRttRejected) => Self::Unknown(Box::new(e)),
            }
        }
    }

    impl quic::RecvStream for RecvStream {
        type Buf = Bytes;

        /// Polls for the next chunk of stream data, in order.
        ///
        /// Returns `Ok(None)` once the stream has finished.
        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            self.execute_poll_read(cx, true, |chunks| match chunks.next(usize::MAX) {
                Ok(Some(chunk)) => ReadStatus::Readable(chunk.bytes),
                res => (None, res.err()).into(),
            })
            .map_err(Into::into)
        }

        /// Stops the stream with `error_code`. A [`ClosedStream`] error is ignored.
        ///
        /// # Panics
        ///
        /// Panics if `error_code` does not fit in a QUIC [`VarInt`].
        fn stop_sending(&mut self, error_code: u64) {
            self.stop(error_code.try_into().expect("invalid error_code"))
                .ok();
        }

        /// Returns this stream's identifier in `h3`'s representation.
        fn recv_id(&self) -> quic::StreamId {
            // NOTE(review): QUIC stream IDs fit in a VarInt, so this conversion
            // is presumably infallible.
            u64::from(self.stream).try_into().unwrap()
        }
    }
}