1use std::{
2 io,
3 mem::MaybeUninit,
4 task::{Context, Poll},
5};
6
7use compio_buf::{BufResult, IoBufMut, bytes::Bytes};
8use compio_io::AsyncRead;
9use futures_util::future::poll_fn;
10use quinn_proto::{Chunk, Chunks, ClosedStream, ReadableError, StreamId, VarInt};
11use thiserror::Error;
12
13use crate::{ConnectionError, ConnectionInner, StoppedError, sync::shared::Shared};
14
/// The receiving half of a QUIC stream, backed by the shared connection state.
#[derive(Debug)]
pub struct RecvStream {
    /// Shared handle to the connection state that owns this stream.
    conn: Shared<ConnectionInner>,
    /// Identifier of this stream within the connection.
    stream: StreamId,
    /// Whether the stream was opened during 0-RTT; operations re-check 0-RTT acceptance.
    is_0rtt: bool,
    /// Set once no more data will be read (stream finished, reset reported by a read, or
    /// stopped locally). Reads then return `Ok(None)` and `Drop` skips stopping the stream.
    all_data_read: bool,
    /// Reset error code seen from the peer, cached so later polls can report it again.
    reset: Option<VarInt>,
}
59
60impl RecvStream {
61 pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
62 Self {
63 conn,
64 stream,
65 is_0rtt,
66 all_data_read: false,
67 reset: None,
68 }
69 }
70
71 pub fn id(&self) -> StreamId {
73 self.stream
74 }
75
76 pub fn is_0rtt(&self) -> bool {
82 self.is_0rtt
83 }
84
85 pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
91 let mut state = self.conn.state();
92 if self.is_0rtt && !state.check_0rtt() {
93 return Ok(());
94 }
95 state.conn.recv_stream(self.stream).stop(error_code)?;
96 state.wake();
97 self.all_data_read = true;
98 Ok(())
99 }
100
101 pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
113 poll_fn(|cx| {
114 let mut state = self.conn.state();
115
116 if self.is_0rtt && !state.check_0rtt() {
117 return Poll::Ready(Err(StoppedError::ZeroRttRejected));
118 }
119 if let Some(code) = self.reset {
120 return Poll::Ready(Ok(Some(code)));
121 }
122
123 match state.conn.recv_stream(self.stream).received_reset() {
124 Err(_) => Poll::Ready(Ok(None)),
125 Ok(Some(error_code)) => {
126 state.wake();
129 Poll::Ready(Ok(Some(error_code)))
130 }
131 Ok(None) => {
132 if let Some(e) = &state.error {
133 return Poll::Ready(Err(e.clone().into()));
134 }
135 state.readable.insert(self.stream, cx.waker().clone());
140 Poll::Pending
141 }
142 }
143 })
144 .await
145 }
146
147 fn execute_poll_read<F, T>(
155 &mut self,
156 cx: &mut Context,
157 ordered: bool,
158 mut read_fn: F,
159 ) -> Poll<Result<Option<T>, ReadError>>
160 where
161 F: FnMut(&mut Chunks) -> ReadStatus<T>,
162 {
163 use quinn_proto::ReadError::*;
164
165 if self.all_data_read {
166 return Poll::Ready(Ok(None));
167 }
168
169 let mut state = self.conn.state();
170 if self.is_0rtt && !state.check_0rtt() {
171 return Poll::Ready(Err(ReadError::ZeroRttRejected));
172 }
173
174 let status = match self.reset {
178 Some(code) => ReadStatus::Failed(None, Reset(code)),
179 None => {
180 let mut recv = state.conn.recv_stream(self.stream);
181 let mut chunks = recv.read(ordered)?;
182 let status = read_fn(&mut chunks);
183 if chunks.finalize().should_transmit() {
184 state.wake();
185 }
186 status
187 }
188 };
189
190 match status {
191 ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))),
192 ReadStatus::Finished(read) => {
193 self.all_data_read = true;
194 Poll::Ready(Ok(read))
195 }
196 ReadStatus::Failed(read, Blocked) => match read {
197 Some(val) => Poll::Ready(Ok(Some(val))),
198 None => {
199 if let Some(error) = &state.error {
200 return Poll::Ready(Err(error.clone().into()));
201 }
202 state.readable.insert(self.stream, cx.waker().clone());
203 Poll::Pending
204 }
205 },
206 ReadStatus::Failed(read, Reset(error_code)) => match read {
207 None => {
208 self.all_data_read = true;
209 self.reset = Some(error_code);
210 Poll::Ready(Err(ReadError::Reset(error_code)))
211 }
212 done => {
213 self.reset = Some(error_code);
214 Poll::Ready(Ok(done))
215 }
216 },
217 }
218 }
219
220 pub(crate) fn poll_read_impl(
221 &mut self,
222 cx: &mut Context,
223 buf: &mut [MaybeUninit<u8>],
224 ) -> Poll<Result<Option<usize>, ReadError>> {
225 if buf.is_empty() {
226 return Poll::Ready(Ok(Some(0)));
227 }
228
229 self.execute_poll_read(cx, true, |chunks| {
230 let mut read = 0;
231 loop {
232 if read >= buf.len() {
233 return ReadStatus::Readable(read);
235 }
236
237 match chunks.next(buf.len() - read) {
238 Ok(Some(chunk)) => {
239 let bytes = chunk.bytes;
240 let len = bytes.len();
241 buf[read..read + len].copy_from_slice(unsafe {
242 std::slice::from_raw_parts(bytes.as_ptr().cast(), len)
243 });
244 read += len;
245 }
246 res => {
247 return (if read == 0 { None } else { Some(read) }, res.err()).into();
248 }
249 }
250 }
251 })
252 }
253
254 pub fn poll_read_uninit(
267 &mut self,
268 cx: &mut Context,
269 buf: &mut [MaybeUninit<u8>],
270 ) -> Poll<Result<usize, ReadError>> {
271 self.poll_read_impl(cx, buf)
272 .map(|res| res.map(|n| n.unwrap_or_default()))
273 }
274
275 pub async fn read_chunk(
292 &mut self,
293 max_length: usize,
294 ordered: bool,
295 ) -> Result<Option<Chunk>, ReadError> {
296 poll_fn(|cx| {
297 self.execute_poll_read(cx, ordered, |chunks| match chunks.next(max_length) {
298 Ok(Some(chunk)) => ReadStatus::Readable(chunk),
299 res => (None, res.err()).into(),
300 })
301 })
302 .await
303 }
304
305 pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Option<usize>, ReadError> {
316 if bufs.is_empty() {
317 return Ok(Some(0));
318 }
319
320 poll_fn(|cx| {
321 self.execute_poll_read(cx, true, |chunks| {
322 let mut read = 0;
323 loop {
324 if read >= bufs.len() {
325 return ReadStatus::Readable(read);
327 }
328
329 match chunks.next(usize::MAX) {
330 Ok(Some(chunk)) => {
331 bufs[read] = chunk.bytes;
332 read += 1;
333 }
334 res => {
335 return (if read == 0 { None } else { Some(read) }, res.err()).into();
336 }
337 }
338 }
339 })
340 })
341 .await
342 }
343
344 pub async fn read_to_end<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
351 let mut start = u64::MAX;
352 let mut end = 0;
353 let mut chunks = vec![];
354 loop {
355 let chunk = match self.read_chunk(usize::MAX, false).await {
356 Ok(Some(chunk)) => chunk,
357 Ok(None) => break,
358 Err(e) => return BufResult(Err(e.into()), buf),
359 };
360 start = start.min(chunk.offset);
361 end = end.max(chunk.offset + chunk.bytes.len() as u64);
362 chunks.push((chunk.offset, chunk.bytes));
363 }
364 if start == u64::MAX || start >= end {
365 return BufResult(Ok(0), buf);
367 }
368 let len = (end - start) as usize;
369 let cap = buf.buf_capacity();
370 let needed = len.saturating_sub(cap);
371 if needed > 0
372 && let Err(e) = buf.reserve(needed)
373 {
374 return BufResult(Err(io::Error::new(io::ErrorKind::OutOfMemory, e)), buf);
375 }
376 let slice = &mut buf.as_uninit()[..len];
377 slice.fill(MaybeUninit::new(0));
378 for (offset, bytes) in chunks {
379 let offset = (offset - start) as usize;
380 let buf_len = bytes.len();
381 slice[offset..offset + buf_len].copy_from_slice(unsafe {
382 std::slice::from_raw_parts(bytes.as_ptr().cast(), buf_len)
383 });
384 }
385 unsafe { buf.advance_to(len) }
386 BufResult(Ok(len), buf)
387 }
388
389 #[cfg(feature = "io-compat")]
391 pub fn into_compat(self) -> CompatRecvStream {
392 CompatRecvStream(self)
393 }
394}
395
396impl Drop for RecvStream {
397 fn drop(&mut self) {
398 let mut state = self.conn.state();
399
400 state.readable.remove(&self.stream);
402
403 if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
404 return;
405 }
406 if !self.all_data_read {
407 let _ = state.conn.recv_stream(self.stream).stop(0u32.into());
409 state.wake();
410 }
411 }
412}
413
414enum ReadStatus<T> {
415 Readable(T),
416 Finished(Option<T>),
417 Failed(Option<T>, quinn_proto::ReadError),
418}
419
420impl<T> From<(Option<T>, Option<quinn_proto::ReadError>)> for ReadStatus<T> {
421 fn from(status: (Option<T>, Option<quinn_proto::ReadError>)) -> Self {
422 match status {
423 (read, None) => Self::Finished(read),
424 (read, Some(e)) => Self::Failed(read, e),
425 }
426 }
427}
428
/// Errors that arise from reading from a [`RecvStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadError {
    /// The peer reset the stream; carries the reset error code.
    #[error("stream reset by peer: error {0}")]
    Reset(VarInt),
    /// The connection failed; carries the connection error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// quinn-proto reports the stream as closed.
    #[error("closed stream")]
    ClosedStream,
    /// An ordered read was attempted after an unordered read on the same stream.
    #[error("ordered read after unordered read")]
    IllegalOrderedRead,
    /// The stream was opened during 0-RTT and 0-RTT was rejected.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
459
460impl From<ReadableError> for ReadError {
461 fn from(e: ReadableError) -> Self {
462 match e {
463 ReadableError::ClosedStream => Self::ClosedStream,
464 ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead,
465 }
466 }
467}
468
469impl From<StoppedError> for ReadError {
470 fn from(e: StoppedError) -> Self {
471 match e {
472 StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
473 StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
474 }
475 }
476}
477
478impl From<ReadError> for io::Error {
479 fn from(x: ReadError) -> Self {
480 use self::ReadError::*;
481 let kind = match x {
482 Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
483 ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
484 IllegalOrderedRead => io::ErrorKind::InvalidInput,
485 };
486 Self::new(kind, x)
487 }
488}
489
/// Errors that arise from reading an exact amount of data from a stream.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadExactError {
    /// The stream finished before the buffer was filled; carries the number of bytes the
    /// buffer still had room for.
    #[error("stream finished early (expected {0} bytes more)")]
    FinishedEarly(usize),
    /// The underlying read failed.
    #[error(transparent)]
    ReadError(#[from] ReadError),
}
500
501impl AsyncRead for RecvStream {
502 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
503 let res = poll_fn(|cx| self.poll_read_uninit(cx, buf.as_uninit()))
504 .await
505 .inspect(|&n| unsafe { buf.advance_to(n) })
506 .map_err(Into::into);
507 BufResult(res, buf)
508 }
509}
510
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
        task::ready,
    };

    use compio_buf::{IntoInner, bytes::BufMut};

    use super::*;

    /// Adapter around [`RecvStream`] offering [`BufMut`]-based reads and
    /// [`futures_util::AsyncRead`].
    pub struct CompatRecvStream(pub(super) RecvStream);

    impl CompatRecvStream {
        /// Poll one ordered read into the next writable chunk of `buf`, advancing `buf` by the
        /// number of bytes written. `Ok(None)` means the stream has finished.
        fn poll_read(
            &mut self,
            cx: &mut Context,
            mut buf: impl BufMut,
        ) -> Poll<Result<Option<usize>, ReadError>> {
            // SAFETY: `poll_read_impl` only ever writes initialized bytes into this slice and
            // never de-initializes any of it.
            let dst = unsafe { buf.chunk_mut().as_uninit_slice_mut() };
            let polled = self.poll_read_impl(cx, dst);
            if let Poll::Ready(Ok(Some(written))) = &polled {
                // SAFETY: the first `written` bytes of the chunk were just initialized.
                unsafe { buf.advance_mut(*written) }
            }
            polled
        }

        /// Read some data into `buf`, returning the number of bytes read, or `None` once the
        /// stream has finished.
        pub async fn read(&mut self, mut buf: impl BufMut) -> Result<Option<usize>, ReadError> {
            poll_fn(|cx| self.poll_read(cx, &mut buf)).await
        }

        /// Keep reading until `buf.has_remaining_mut()` is false.
        ///
        /// NOTE(review): for growable buffers, `has_remaining_mut` may stay true for a very
        /// long time (e.g. `Vec<u8>`), in which case this reads until the stream ends and then
        /// reports `FinishedEarly`.
        ///
        /// # Errors
        ///
        /// [`ReadExactError::FinishedEarly`] carrying `buf.remaining_mut()` if the stream ends
        /// first, or the underlying [`ReadError`].
        pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> {
            poll_fn(|cx| loop {
                if !buf.has_remaining_mut() {
                    return Poll::Ready(Ok(()));
                }
                if ready!(self.poll_read(cx, &mut buf))?.is_none() {
                    return Poll::Ready(Err(ReadExactError::FinishedEarly(
                        buf.remaining_mut(),
                    )));
                }
            })
            .await
        }
    }

    impl IntoInner for CompatRecvStream {
        type Inner = RecvStream;

        /// Unwrap the adapter, returning the underlying [`RecvStream`].
        fn into_inner(self) -> Self::Inner {
            self.0
        }
    }

    impl Deref for CompatRecvStream {
        type Target = RecvStream;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl DerefMut for CompatRecvStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.0
        }
    }

    impl futures_util::AsyncRead for CompatRecvStream {
        /// Ordered read into `buf`; `Ok(0)` signals end of stream (or an empty `buf`).
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let len = buf.len();
            // SAFETY: `u8` and `MaybeUninit<u8>` share a layout, and `poll_read_uninit` only
            // writes initialized bytes, so the caller's slice never ends up holding uninit data.
            let dst: &mut [MaybeUninit<u8>] =
                unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), len) };
            self.get_mut()
                .poll_read_uninit(cx, dst)
                .map_err(io::Error::from)
        }
    }
}
608
609#[cfg(feature = "io-compat")]
610pub use compat::CompatRecvStream;
611
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use h3::quic::{self, StreamErrorIncoming};

    use super::*;

    impl From<ReadError> for StreamErrorIncoming {
        /// Translate a [`ReadError`] into the error type `h3` expects from a receive stream.
        fn from(err: ReadError) -> Self {
            match err {
                ReadError::Reset(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ReadError::ConnectionLost(conn_err) => Self::ConnectionErrorIncoming {
                    connection_error: conn_err.into(),
                },
                // `poll_data` only performs ordered reads.
                // NOTE(review): an earlier unordered `read_chunk` on the same stream would
                // still make this reachable.
                ReadError::IllegalOrderedRead => unreachable!("illegal ordered read"),
                other @ (ReadError::ClosedStream | ReadError::ZeroRttRejected) => {
                    Self::Unknown(Box::new(other))
                }
            }
        }
    }

    impl quic::RecvStream for RecvStream {
        type Buf = Bytes;

        /// Poll for the next ordered chunk of data; `Ok(None)` once the stream has finished.
        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            let polled = self.execute_poll_read(cx, true, |chunks| {
                match chunks.next(usize::MAX) {
                    Ok(Some(chunk)) => ReadStatus::Readable(chunk.bytes),
                    other => ReadStatus::from((None, other.err())),
                }
            });
            polled.map_err(StreamErrorIncoming::from)
        }

        /// Stop the stream with `error_code`.
        ///
        /// # Panics
        ///
        /// Panics if `error_code` does not fit in a QUIC varint.
        fn stop_sending(&mut self, error_code: u64) {
            let code = VarInt::try_from(error_code).expect("invalid error_code");
            // A stream that is already closed needs no further stopping.
            let _ = self.stop(code);
        }

        /// The stream id converted into `h3`'s stream id type.
        fn recv_id(&self) -> quic::StreamId {
            let raw: u64 = self.stream.into();
            raw.try_into().unwrap()
        }
    }
}
657}