1use std::{
2 io,
3 mem::MaybeUninit,
4 task::{Context, Poll},
5};
6
7use compio_buf::{BufResult, IntoInner, IoBufMut, bytes::Bytes};
8use compio_io::AsyncRead;
9use futures_util::future::poll_fn;
10use quinn_proto::{Chunk, Chunks, ClosedStream, ReadableError, StreamId, VarInt};
11use thiserror::Error;
12
13use crate::{ConnectionError, ConnectionInner, sync::shared::Shared};
14
/// A stream that can only be used to receive data.
///
/// If the handle is dropped before all data was read (and without calling
/// [`stop`](Self::stop)), its `Drop` impl stops the stream with error code 0,
/// unless the connection has already failed or the stream's 0-RTT check fails.
#[derive(Debug)]
pub struct RecvStream {
    /// Shared connection state; locked through `state()` for every operation.
    conn: Shared<ConnectionInner>,
    /// Identifier of this stream within the connection.
    stream: StreamId,
    /// Whether the stream was opened with 0-RTT data; such streams re-run the
    /// connection's 0-RTT check before each operation.
    is_0rtt: bool,
    /// Set once end of stream, a reported reset, or `stop` was observed. Later
    /// reads then return end of stream immediately and `Drop` skips its own stop.
    all_data_read: bool,
    /// Reset code observed by a read, kept so it can be reported again later.
    reset: Option<VarInt>,
}
59
60impl RecvStream {
61 pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
62 Self {
63 conn,
64 stream,
65 is_0rtt,
66 all_data_read: false,
67 reset: None,
68 }
69 }
70
71 pub fn id(&self) -> StreamId {
73 self.stream
74 }
75
76 pub fn is_0rtt(&self) -> bool {
82 self.is_0rtt
83 }
84
85 pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
91 let mut state = self.conn.state();
92 if self.is_0rtt && !state.check_0rtt() {
93 return Ok(());
94 }
95 state.conn.recv_stream(self.stream).stop(error_code)?;
96 state.wake();
97 self.all_data_read = true;
98 Ok(())
99 }
100
101 pub async fn received_reset(&mut self) -> Result<Option<VarInt>, ResetError> {
112 poll_fn(|cx| {
113 let mut state = self.conn.state();
114
115 if self.is_0rtt && !state.check_0rtt() {
116 return Poll::Ready(Err(ResetError::ZeroRttRejected));
117 }
118 if let Some(code) = self.reset {
119 return Poll::Ready(Ok(Some(code)));
120 }
121
122 match state.conn.recv_stream(self.stream).received_reset() {
123 Err(_) => Poll::Ready(Ok(None)),
124 Ok(Some(error_code)) => {
125 state.wake();
128 Poll::Ready(Ok(Some(error_code)))
129 }
130 Ok(None) => {
131 if let Some(e) = &state.error {
132 return Poll::Ready(Err(e.clone().into()));
133 }
134 state.readable.insert(self.stream, cx.waker().clone());
139 Poll::Pending
140 }
141 }
142 })
143 .await
144 }
145
146 fn execute_poll_read<F, T>(
154 &mut self,
155 cx: &mut Context,
156 ordered: bool,
157 mut read_fn: F,
158 ) -> Poll<Result<Option<T>, ReadError>>
159 where
160 F: FnMut(&mut Chunks) -> ReadStatus<T>,
161 {
162 use quinn_proto::ReadError::*;
163
164 if self.all_data_read {
165 return Poll::Ready(Ok(None));
166 }
167
168 let mut state = self.conn.state();
169 if self.is_0rtt && !state.check_0rtt() {
170 return Poll::Ready(Err(ReadError::ZeroRttRejected));
171 }
172
173 let status = match self.reset {
177 Some(code) => ReadStatus::Failed(None, Reset(code)),
178 None => {
179 let mut recv = state.conn.recv_stream(self.stream);
180 let mut chunks = recv.read(ordered)?;
181 let status = read_fn(&mut chunks);
182 if chunks.finalize().should_transmit() {
183 state.wake();
184 }
185 status
186 }
187 };
188
189 match status {
190 ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))),
191 ReadStatus::Finished(read) => {
192 self.all_data_read = true;
193 Poll::Ready(Ok(read))
194 }
195 ReadStatus::Failed(read, Blocked) => match read {
196 Some(val) => Poll::Ready(Ok(Some(val))),
197 None => {
198 if let Some(error) = &state.error {
199 return Poll::Ready(Err(error.clone().into()));
200 }
201 state.readable.insert(self.stream, cx.waker().clone());
202 Poll::Pending
203 }
204 },
205 ReadStatus::Failed(read, Reset(error_code)) => match read {
206 None => {
207 self.all_data_read = true;
208 self.reset = Some(error_code);
209 Poll::Ready(Err(ReadError::Reset(error_code)))
210 }
211 done => {
212 self.reset = Some(error_code);
213 Poll::Ready(Ok(done))
214 }
215 },
216 }
217 }
218
219 pub(crate) fn poll_read_impl(
220 &mut self,
221 cx: &mut Context,
222 buf: &mut [MaybeUninit<u8>],
223 ) -> Poll<Result<Option<usize>, ReadError>> {
224 if buf.is_empty() {
225 return Poll::Ready(Ok(Some(0)));
226 }
227
228 self.execute_poll_read(cx, true, |chunks| {
229 let mut read = 0;
230 loop {
231 if read >= buf.len() {
232 return ReadStatus::Readable(read);
234 }
235
236 match chunks.next(buf.len() - read) {
237 Ok(Some(chunk)) => {
238 let bytes = chunk.bytes;
239 let len = bytes.len();
240 buf[read..read + len].copy_from_slice(unsafe {
241 std::slice::from_raw_parts(bytes.as_ptr().cast(), len)
242 });
243 read += len;
244 }
245 res => {
246 return (if read == 0 { None } else { Some(read) }, res.err()).into();
247 }
248 }
249 }
250 })
251 }
252
253 pub fn poll_read_uninit(
266 &mut self,
267 cx: &mut Context,
268 buf: &mut [MaybeUninit<u8>],
269 ) -> Poll<Result<usize, ReadError>> {
270 self.poll_read_impl(cx, buf)
271 .map(|res| res.map(|n| n.unwrap_or_default()))
272 }
273
274 pub async fn read_chunk(
291 &mut self,
292 max_length: usize,
293 ordered: bool,
294 ) -> Result<Option<Chunk>, ReadError> {
295 poll_fn(|cx| {
296 self.execute_poll_read(cx, ordered, |chunks| match chunks.next(max_length) {
297 Ok(Some(chunk)) => ReadStatus::Readable(chunk),
298 res => (None, res.err()).into(),
299 })
300 })
301 .await
302 }
303
304 pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Option<usize>, ReadError> {
315 if bufs.is_empty() {
316 return Ok(Some(0));
317 }
318
319 poll_fn(|cx| {
320 self.execute_poll_read(cx, true, |chunks| {
321 let mut read = 0;
322 loop {
323 if read >= bufs.len() {
324 return ReadStatus::Readable(read);
326 }
327
328 match chunks.next(usize::MAX) {
329 Ok(Some(chunk)) => {
330 bufs[read] = chunk.bytes;
331 read += 1;
332 }
333 res => {
334 return (if read == 0 { None } else { Some(read) }, res.err()).into();
335 }
336 }
337 }
338 })
339 })
340 .await
341 }
342
343 pub async fn read_to_end<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
350 let mut start = u64::MAX;
351 let mut end = 0;
352 let mut chunks = vec![];
353 loop {
354 let chunk = match self.read_chunk(usize::MAX, false).await {
355 Ok(Some(chunk)) => chunk,
356 Ok(None) => break,
357 Err(e) => return BufResult(Err(e.into()), buf),
358 };
359 start = start.min(chunk.offset);
360 end = end.max(chunk.offset + chunk.bytes.len() as u64);
361 chunks.push((chunk.offset, chunk.bytes));
362 }
363 if start == u64::MAX || start >= end {
364 return BufResult(Ok(0), buf);
366 }
367 let len = (end - start) as usize;
368 let cap = buf.buf_capacity();
369 let needed = len.saturating_sub(cap);
370 if needed > 0
371 && let Err(e) = buf.reserve(needed)
372 {
373 return BufResult(Err(io::Error::new(io::ErrorKind::OutOfMemory, e)), buf);
374 }
375 let mut buf = buf.slice(..len);
376 let slice = buf.ensure_init();
377 for (offset, bytes) in chunks {
378 let offset = (offset - start) as usize;
379 let buf_len = bytes.len();
380 slice[offset..offset + buf_len].copy_from_slice(&bytes);
381 }
382 let mut buf = buf.into_inner();
383 unsafe { buf.advance_to(len) }
384 BufResult(Ok(len), buf)
385 }
386
387 #[cfg(feature = "io-compat")]
389 pub fn into_compat(self) -> CompatRecvStream {
390 CompatRecvStream(self)
391 }
392}
393
394impl Drop for RecvStream {
395 fn drop(&mut self) {
396 let mut state = self.conn.state();
397
398 state.readable.remove(&self.stream);
400
401 if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
402 return;
403 }
404 if !self.all_data_read {
405 let _ = state.conn.recv_stream(self.stream).stop(0u32.into());
407 state.wake();
408 }
409 }
410}
411
/// Outcome of one `read_fn` invocation inside `RecvStream::execute_poll_read`.
enum ReadStatus<T> {
    /// Data was read and the stream may still hold more.
    Readable(T),
    /// The chunk reader reported no error and no further chunk (end of stream),
    /// together with any data read before that point.
    Finished(Option<T>),
    /// Reading stopped on a quinn-proto error (e.g. `Blocked` or `Reset`),
    /// together with any data read before it.
    Failed(Option<T>, quinn_proto::ReadError),
}
417
418impl<T> From<(Option<T>, Option<quinn_proto::ReadError>)> for ReadStatus<T> {
419 fn from(status: (Option<T>, Option<quinn_proto::ReadError>)) -> Self {
420 match status {
421 (read, None) => Self::Finished(read),
422 (read, Some(e)) => Self::Failed(read, e),
423 }
424 }
425}
426
/// Errors returned when reading from a [`RecvStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadError {
    /// The peer reset the stream; carries the peer's error code.
    #[error("stream reset by peer: error {0}")]
    Reset(VarInt),
    /// The connection was lost.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// quinn-proto reported the stream as already closed.
    #[error("closed stream")]
    ClosedStream,
    /// An ordered read was attempted after an unordered read.
    #[error("ordered read after unordered read")]
    IllegalOrderedRead,
    /// The stream was opened with 0-RTT data and the connection's 0-RTT check
    /// failed.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
457
458impl From<ReadableError> for ReadError {
459 fn from(e: ReadableError) -> Self {
460 match e {
461 ReadableError::ClosedStream => Self::ClosedStream,
462 ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead,
463 }
464 }
465}
466
467impl From<ResetError> for ReadError {
468 fn from(e: ResetError) -> Self {
469 match e {
470 ResetError::ConnectionLost(e) => Self::ConnectionLost(e),
471 ResetError::ZeroRttRejected => Self::ZeroRttRejected,
472 }
473 }
474}
475
476impl From<ReadError> for io::Error {
477 fn from(x: ReadError) -> Self {
478 use self::ReadError::*;
479 let kind = match x {
480 Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
481 ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
482 IllegalOrderedRead => io::ErrorKind::InvalidInput,
483 };
484 Self::new(kind, x)
485 }
486}
487
/// Errors returned by `read_exact`-style reads that must fill a whole buffer.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadExactError {
    /// The stream ended before the buffer was filled; carries the remaining
    /// writable space of the buffer at that point.
    #[error("stream finished early (expected {0} bytes more)")]
    FinishedEarly(usize),
    /// A read failed with the wrapped [`ReadError`].
    #[error(transparent)]
    ReadError(#[from] ReadError),
}
498
/// Errors returned by [`RecvStream::received_reset`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ResetError {
    /// The connection was lost.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The stream was opened with 0-RTT data and the connection's 0-RTT check
    /// failed.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
514
515impl From<ResetError> for io::Error {
516 fn from(x: ResetError) -> Self {
517 use ResetError::*;
518 let kind = match x {
519 ZeroRttRejected => io::ErrorKind::ConnectionReset,
520 ConnectionLost(_) => io::ErrorKind::NotConnected,
521 };
522 Self::new(kind, x)
523 }
524}
525
526impl AsyncRead for RecvStream {
527 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
528 let res = poll_fn(|cx| self.poll_read_uninit(cx, buf.as_uninit()))
529 .await
530 .inspect(|&n| unsafe { buf.advance_to(n) })
531 .map_err(Into::into);
532 BufResult(res, buf)
533 }
534}
535
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
        task::ready,
    };

    use compio_buf::{IntoInner, bytes::BufMut};

    use super::*;

    /// Compatibility wrapper around [`RecvStream`].
    ///
    /// Offers `bytes::BufMut`-based reads and implements
    /// `futures_util::AsyncRead`. Derefs to the inner [`RecvStream`]; use
    /// [`IntoInner::into_inner`] to unwrap it.
    #[derive(Debug)]
    pub struct CompatRecvStream(pub(super) RecvStream);

    impl CompatRecvStream {
        /// Polls for ordered data and writes it into the current chunk of `buf`.
        ///
        /// Advances `buf` by the number of bytes written. Resolves to
        /// `Ok(Some(n))` after writing `n` bytes, or `Ok(None)` at end of stream.
        fn poll_read(
            &mut self,
            cx: &mut Context,
            mut buf: impl BufMut,
        ) -> Poll<Result<Option<usize>, ReadError>> {
            // SAFETY: `poll_read_impl` only writes initialized bytes (copied from
            // received chunks) into the slice, so nothing is de-initialized.
            self.poll_read_impl(cx, unsafe { buf.chunk_mut().as_uninit_slice_mut() })
                .map(|res| {
                    if let Ok(Some(n)) = &res {
                        // SAFETY: the first `n` bytes of the chunk were just written.
                        unsafe { buf.advance_mut(*n) }
                    }
                    res
                })
        }

        /// Reads ordered data into `buf`.
        ///
        /// Returns the number of bytes written, or `Ok(None)` at end of stream.
        pub async fn read(&mut self, mut buf: impl BufMut) -> Result<Option<usize>, ReadError> {
            poll_fn(|cx| self.poll_read(cx, &mut buf)).await
        }

        /// Reads until `buf` has no remaining writable space.
        ///
        /// # Errors
        ///
        /// Returns [`ReadExactError::FinishedEarly`] with `buf.remaining_mut()`
        /// if the stream ends first, or the underlying [`ReadError`].
        ///
        /// NOTE(review): for growable buffers such as `Vec<u8>`,
        /// `has_remaining_mut` stays true (per the `bytes` docs). This method
        /// then reads until the stream ends and reports `FinishedEarly`; confirm
        /// this is intended.
        pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> {
            poll_fn(|cx| {
                while buf.has_remaining_mut() {
                    if ready!(self.poll_read(cx, &mut buf))?.is_none() {
                        return Poll::Ready(Err(ReadExactError::FinishedEarly(
                            buf.remaining_mut(),
                        )));
                    }
                }
                Poll::Ready(Ok(()))
            })
            .await
        }
    }

    impl IntoInner for CompatRecvStream {
        type Inner = RecvStream;

        fn into_inner(self) -> Self::Inner {
            self.0
        }
    }

    impl Deref for CompatRecvStream {
        type Target = RecvStream;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl DerefMut for CompatRecvStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.0
        }
    }

    impl futures_util::AsyncRead for CompatRecvStream {
        /// Reads ordered data into `buf`, returning `0` at end of stream.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            self.get_mut()
                // SAFETY: `[u8]` and `[MaybeUninit<u8>]` share a layout, and the
                // read path only writes initialized bytes, so the caller's slice
                // stays fully initialized.
                .poll_read_uninit(cx, unsafe {
                    std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len())
                })
                .map_err(Into::into)
        }
    }
}
633
634#[cfg(feature = "io-compat")]
635pub use compat::CompatRecvStream;
636
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use h3::quic::{self, StreamErrorIncoming};

    use super::*;

    impl From<ReadError> for StreamErrorIncoming {
        /// Maps a [`ReadError`] onto h3's incoming stream error.
        ///
        /// Resets and connection loss get their dedicated h3 variants; every
        /// other error is passed through as `Unknown`.
        ///
        /// `IllegalOrderedRead` cannot come from `poll_data`, which only reads
        /// in order. It can still reach this conversion when a caller mixed in
        /// unordered `read_chunk` calls, so it is reported rather than
        /// panicking.
        fn from(e: ReadError) -> Self {
            match e {
                ReadError::Reset(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ReadError::ConnectionLost(e) => Self::ConnectionErrorIncoming {
                    connection_error: e.into(),
                },
                e @ (ReadError::ClosedStream
                | ReadError::IllegalOrderedRead
                | ReadError::ZeroRttRejected) => Self::Unknown(Box::new(e)),
            }
        }
    }

    impl quic::RecvStream for RecvStream {
        type Buf = Bytes;

        /// Polls for the next in-order chunk of data; `Ok(None)` at end of stream.
        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            self.execute_poll_read(cx, true, |chunks| match chunks.next(usize::MAX) {
                Ok(Some(chunk)) => ReadStatus::Readable(chunk.bytes),
                res => (None, res.err()).into(),
            })
            .map_err(Into::into)
        }

        /// Stops the stream with `error_code`, ignoring an already-closed stream.
        ///
        /// # Panics
        ///
        /// Panics if `error_code` does not fit in a QUIC [`VarInt`].
        fn stop_sending(&mut self, error_code: u64) {
            self.stop(error_code.try_into().expect("invalid error_code"))
                .ok();
        }

        /// Returns this stream's ID converted to h3's stream ID type.
        fn recv_id(&self) -> quic::StreamId {
            u64::from(self.stream)
                .try_into()
                .expect("QUIC stream ID does not fit in an h3 stream ID")
        }
    }
}