Skip to main content

compio_quic/
recv_stream.rs

1use std::{
2    io,
3    mem::MaybeUninit,
4    task::{Context, Poll},
5};
6
7use compio_buf::{BufResult, IoBufMut, bytes::Bytes};
8use compio_io::AsyncRead;
9use futures_util::future::poll_fn;
10use quinn_proto::{Chunk, Chunks, ClosedStream, ReadableError, StreamId, VarInt};
11use thiserror::Error;
12
13use crate::{ConnectionError, ConnectionInner, sync::shared::Shared};
14
/// A stream that can only be used to receive data
///
/// `stop(0)` is implicitly called on drop unless:
/// - A variant of [`ReadError`] has been yielded by a read call
/// - [`stop()`] was called explicitly
///
/// # Cancellation
///
/// A `read` method is said to be *cancel-safe* when dropping its future before
/// the future becomes ready cannot lead to loss of stream data. This is true of
/// methods which succeed immediately when any progress is made, and is not true
/// of methods which might need to perform multiple reads internally before
/// succeeding. Each `read` method documents whether it is cancel-safe.
///
/// # Common issues
///
/// ## Data never received on a locally-opened stream
///
/// Peers are not notified of streams until they or a later-numbered stream are
/// used to send data. If a bidirectional stream is locally opened but never
/// used to send, then the peer may never see it. Application protocols should
/// always arrange for the endpoint which will first transmit on a stream to be
/// the endpoint responsible for opening it.
///
/// ## Data never received on a remotely-opened stream
///
/// Verify that the stream you are receiving is the same one that the server is
/// sending on, e.g. by logging the [`id`] of each. Streams are always accepted
/// in the same order as they are created, i.e. ascending order by [`StreamId`].
/// For example, even if a sender first transmits on bidirectional stream 1, the
/// first stream yielded by [`Connection::accept_bi`] on the receiver
/// will be bidirectional stream 0.
///
/// [`stop()`]: RecvStream::stop
/// [`id`]: RecvStream::id
/// [`Connection::accept_bi`]: crate::Connection::accept_bi
#[derive(Debug)]
pub struct RecvStream {
    // Shared connection state: used to read stream data, register wakers in
    // `readable`, and wake the connection driver.
    conn: Shared<ConnectionInner>,
    // Identifier of the QUIC stream this handle reads from.
    stream: StreamId,
    // Whether the stream was opened during 0-RTT; reads fail with
    // `ZeroRttRejected` if the 0-RTT check no longer passes.
    is_0rtt: bool,
    // Set once the stream was finished, stopped, or a reset was reported to
    // the reader. Later reads yield `Ok(None)`, and it suppresses the
    // implicit `stop(0)` on drop.
    all_data_read: bool,
    // Reset code received from the peer. It is remembered so it can be
    // reported on the next read when data was returned alongside it, and so
    // `received_reset` can keep reporting it.
    reset: Option<VarInt>,
}
59
60impl RecvStream {
61    pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
62        Self {
63            conn,
64            stream,
65            is_0rtt,
66            all_data_read: false,
67            reset: None,
68        }
69    }
70
71    /// Get the identity of this stream
72    pub fn id(&self) -> StreamId {
73        self.stream
74    }
75
76    /// Check if this stream has been opened during 0-RTT.
77    ///
78    /// In which case any non-idempotent request should be considered dangerous
79    /// at the application level. Because read data is subject to replay
80    /// attacks.
81    pub fn is_0rtt(&self) -> bool {
82        self.is_0rtt
83    }
84
85    /// Stop accepting data
86    ///
87    /// Discards unread data and notifies the peer to stop transmitting. Once
88    /// stopped, further attempts to operate on a stream will yield
89    /// `ClosedStream` errors.
90    pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
91        let mut state = self.conn.state();
92        if self.is_0rtt && !state.check_0rtt() {
93            return Ok(());
94        }
95        state.conn.recv_stream(self.stream).stop(error_code)?;
96        state.wake();
97        self.all_data_read = true;
98        Ok(())
99    }
100
101    /// Completes when the stream has been reset by the peer or otherwise closed
102    ///
103    /// Yields `Some` with the reset error code when the stream is reset by the
104    /// peer. Yields `None` when the stream was previously
105    /// [`stop()`](Self::stop)ed, or when the stream was
106    /// [`finish()`](crate::SendStream::finish)ed by the peer and all data has
107    /// been received, after which it is no longer meaningful for the stream
108    /// to be reset.
109    ///
110    /// This operation is cancel-safe.
111    pub async fn received_reset(&mut self) -> Result<Option<VarInt>, ResetError> {
112        poll_fn(|cx| {
113            let mut state = self.conn.state();
114
115            if self.is_0rtt && !state.check_0rtt() {
116                return Poll::Ready(Err(ResetError::ZeroRttRejected));
117            }
118            if let Some(code) = self.reset {
119                return Poll::Ready(Ok(Some(code)));
120            }
121
122            match state.conn.recv_stream(self.stream).received_reset() {
123                Err(_) => Poll::Ready(Ok(None)),
124                Ok(Some(error_code)) => {
125                    // Stream state has just now been freed, so the connection may need to issue new
126                    // stream ID flow control credit
127                    state.wake();
128                    Poll::Ready(Ok(Some(error_code)))
129                }
130                Ok(None) => {
131                    if let Some(e) = &state.error {
132                        return Poll::Ready(Err(e.clone().into()));
133                    }
134                    // Resets always notify readers, since a reset is an immediate read error. We
135                    // could introduce a dedicated channel to reduce the risk of spurious wakeups,
136                    // but that increased complexity is probably not justified, as an application
137                    // that is expecting a reset is not likely to receive large amounts of data.
138                    state.readable.insert(self.stream, cx.waker().clone());
139                    Poll::Pending
140                }
141            }
142        })
143        .await
144    }
145
146    /// Handle common logic related to reading out of a receive stream.
147    ///
148    /// This takes an `FnMut` closure that takes care of the actual reading
149    /// process, matching the detailed read semantics for the calling
150    /// function with a particular return type. The closure can read from
151    /// the passed `&mut Chunks` and has to return the status after reading:
152    /// the amount of data read, and the status after the final read call.
153    fn execute_poll_read<F, T>(
154        &mut self,
155        cx: &mut Context,
156        ordered: bool,
157        mut read_fn: F,
158    ) -> Poll<Result<Option<T>, ReadError>>
159    where
160        F: FnMut(&mut Chunks) -> ReadStatus<T>,
161    {
162        use quinn_proto::ReadError::*;
163
164        if self.all_data_read {
165            return Poll::Ready(Ok(None));
166        }
167
168        let mut state = self.conn.state();
169        if self.is_0rtt && !state.check_0rtt() {
170            return Poll::Ready(Err(ReadError::ZeroRttRejected));
171        }
172
173        // If we stored an error during a previous call, return it now. This can happen
174        // if a `read_fn` both wants to return data and also returns an error in
175        // its final stream status.
176        let status = match self.reset {
177            Some(code) => ReadStatus::Failed(None, Reset(code)),
178            None => {
179                let mut recv = state.conn.recv_stream(self.stream);
180                let mut chunks = recv.read(ordered)?;
181                let status = read_fn(&mut chunks);
182                if chunks.finalize().should_transmit() {
183                    state.wake();
184                }
185                status
186            }
187        };
188
189        match status {
190            ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))),
191            ReadStatus::Finished(read) => {
192                self.all_data_read = true;
193                Poll::Ready(Ok(read))
194            }
195            ReadStatus::Failed(read, Blocked) => match read {
196                Some(val) => Poll::Ready(Ok(Some(val))),
197                None => {
198                    if let Some(error) = &state.error {
199                        return Poll::Ready(Err(error.clone().into()));
200                    }
201                    state.readable.insert(self.stream, cx.waker().clone());
202                    Poll::Pending
203                }
204            },
205            ReadStatus::Failed(read, Reset(error_code)) => match read {
206                None => {
207                    self.all_data_read = true;
208                    self.reset = Some(error_code);
209                    Poll::Ready(Err(ReadError::Reset(error_code)))
210                }
211                done => {
212                    self.reset = Some(error_code);
213                    Poll::Ready(Ok(done))
214                }
215            },
216        }
217    }
218
219    pub(crate) fn poll_read_impl(
220        &mut self,
221        cx: &mut Context,
222        buf: &mut [MaybeUninit<u8>],
223    ) -> Poll<Result<Option<usize>, ReadError>> {
224        if buf.is_empty() {
225            return Poll::Ready(Ok(Some(0)));
226        }
227
228        self.execute_poll_read(cx, true, |chunks| {
229            let mut read = 0;
230            loop {
231                if read >= buf.len() {
232                    // We know `read > 0` because `buf` cannot be empty here
233                    return ReadStatus::Readable(read);
234                }
235
236                match chunks.next(buf.len() - read) {
237                    Ok(Some(chunk)) => {
238                        let bytes = chunk.bytes;
239                        let len = bytes.len();
240                        buf[read..read + len].copy_from_slice(unsafe {
241                            std::slice::from_raw_parts(bytes.as_ptr().cast(), len)
242                        });
243                        read += len;
244                    }
245                    res => {
246                        return (if read == 0 { None } else { Some(read) }, res.err()).into();
247                    }
248                }
249            }
250        })
251    }
252
253    /// Attempts to read from the stream into the provided buffer
254    ///
255    /// On success, returns `Poll::Ready(Ok(num_bytes_read))` and places data
256    /// into `buf`. If the buffer passed in has non-zero length and a 0 is
257    /// returned, that indicates that the remote side has [`finish`]ed the
258    /// stream and the local side has already read all bytes.
259    ///
260    /// If no data is available for reading, this returns `Poll::Pending` and
261    /// arranges for the current task (via `cx.waker()`) to be notified when
262    /// the stream becomes readable or is closed.
263    ///
264    /// [`finish`]: crate::SendStream::finish
265    pub fn poll_read_uninit(
266        &mut self,
267        cx: &mut Context,
268        buf: &mut [MaybeUninit<u8>],
269    ) -> Poll<Result<usize, ReadError>> {
270        self.poll_read_impl(cx, buf)
271            .map(|res| res.map(|n| n.unwrap_or_default()))
272    }
273
274    /// Read the next segment of data.
275    ///
276    /// Yields `None` if the stream was finished. Otherwise, yields a segment of
277    /// data and its offset in the stream. If `ordered` is `true`, the chunk's
278    /// offset will be immediately after the last data yielded by
279    /// [`read()`](Self::read) or [`read_chunk()`](Self::read_chunk). If
280    /// `ordered` is `false`, segments may be received in any order, and the
281    /// `Chunk`'s `offset` field can be used to determine ordering in the
282    /// caller. Unordered reads are less prone to head-of-line blocking within a
283    /// stream, but require the application to manage reassembling the original
284    /// data.
285    ///
286    /// Slightly more efficient than `read` due to not copying. Chunk boundaries
287    /// do not correspond to peer writes, and hence cannot be used as framing.
288    ///
289    /// This operation is cancel-safe.
290    pub async fn read_chunk(
291        &mut self,
292        max_length: usize,
293        ordered: bool,
294    ) -> Result<Option<Chunk>, ReadError> {
295        poll_fn(|cx| {
296            self.execute_poll_read(cx, ordered, |chunks| match chunks.next(max_length) {
297                Ok(Some(chunk)) => ReadStatus::Readable(chunk),
298                res => (None, res.err()).into(),
299            })
300        })
301        .await
302    }
303
304    /// Read the next segments of data.
305    ///
306    /// Fills `bufs` with the segments of data beginning immediately after the
307    /// last data yielded by `read` or `read_chunk`, or `None` if the stream was
308    /// finished.
309    ///
310    /// Slightly more efficient than `read` due to not copying. Chunk boundaries
311    /// do not correspond to peer writes, and hence cannot be used as framing.
312    ///
313    /// This operation is cancel-safe.
314    pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Option<usize>, ReadError> {
315        if bufs.is_empty() {
316            return Ok(Some(0));
317        }
318
319        poll_fn(|cx| {
320            self.execute_poll_read(cx, true, |chunks| {
321                let mut read = 0;
322                loop {
323                    if read >= bufs.len() {
324                        // We know `read > 0` because `bufs` cannot be empty here
325                        return ReadStatus::Readable(read);
326                    }
327
328                    match chunks.next(usize::MAX) {
329                        Ok(Some(chunk)) => {
330                            bufs[read] = chunk.bytes;
331                            read += 1;
332                        }
333                        res => {
334                            return (if read == 0 { None } else { Some(read) }, res.err()).into();
335                        }
336                    }
337                }
338            })
339        })
340        .await
341    }
342
343    /// Convenience method to read all remaining data into a buffer.
344    ///
345    /// If unordered reads have already been made, the resulting buffer may have
346    /// gaps containing zeros.
347    ///
348    /// This operation is *not* cancel-safe.
349    pub async fn read_to_end<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
350        let mut start = u64::MAX;
351        let mut end = 0;
352        let mut chunks = vec![];
353        loop {
354            let chunk = match self.read_chunk(usize::MAX, false).await {
355                Ok(Some(chunk)) => chunk,
356                Ok(None) => break,
357                Err(e) => return BufResult(Err(e.into()), buf),
358            };
359            start = start.min(chunk.offset);
360            end = end.max(chunk.offset + chunk.bytes.len() as u64);
361            chunks.push((chunk.offset, chunk.bytes));
362        }
363        if start == u64::MAX || start >= end {
364            // no data read
365            return BufResult(Ok(0), buf);
366        }
367        let len = (end - start) as usize;
368        let cap = buf.buf_capacity();
369        let needed = len.saturating_sub(cap);
370        if needed > 0
371            && let Err(e) = buf.reserve(needed)
372        {
373            return BufResult(Err(io::Error::new(io::ErrorKind::OutOfMemory, e)), buf);
374        }
375        let slice = &mut buf.as_uninit()[..len];
376        slice.fill(MaybeUninit::new(0));
377        for (offset, bytes) in chunks {
378            let offset = (offset - start) as usize;
379            let buf_len = bytes.len();
380            slice[offset..offset + buf_len].copy_from_slice(unsafe {
381                std::slice::from_raw_parts(bytes.as_ptr().cast(), buf_len)
382            });
383        }
384        unsafe { buf.advance_to(len) }
385        BufResult(Ok(len), buf)
386    }
387
388    /// Convert into an [`futures_util`] compatible stream.
389    #[cfg(feature = "io-compat")]
390    pub fn into_compat(self) -> CompatRecvStream {
391        CompatRecvStream(self)
392    }
393}
394
395impl Drop for RecvStream {
396    fn drop(&mut self) {
397        let mut state = self.conn.state();
398
399        // clean up any previously registered wakers
400        state.readable.remove(&self.stream);
401
402        if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
403            return;
404        }
405        if !self.all_data_read {
406            // Ignore ClosedStream errors
407            let _ = state.conn.recv_stream(self.stream).stop(0u32.into());
408            state.wake();
409        }
410    }
411}
412
/// Result of one invocation of the `read_fn` closure passed to
/// `RecvStream::execute_poll_read`.
enum ReadStatus<T> {
    /// Data was read, and reading stopped without hitting the end of the
    /// stream or an error.
    Readable(T),
    /// The end of the stream was reached; carries any data read before it.
    Finished(Option<T>),
    /// Reading stopped on a `quinn_proto` read error (blocked or reset);
    /// carries any data read before the error.
    Failed(Option<T>, quinn_proto::ReadError),
}
418
419impl<T> From<(Option<T>, Option<quinn_proto::ReadError>)> for ReadStatus<T> {
420    fn from(status: (Option<T>, Option<quinn_proto::ReadError>)) -> Self {
421        match status {
422            (read, None) => Self::Finished(read),
423            (read, Some(e)) => Self::Failed(read, e),
424        }
425    }
426}
427
/// Errors that arise from reading from a stream.
///
/// Converts into [`io::Error`] with a matching [`io::ErrorKind`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadError {
    /// The peer abandoned transmitting data on this stream.
    ///
    /// Carries an application-defined error code.
    #[error("stream reset by peer: error {0}")]
    Reset(VarInt),
    /// The connection was lost.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The stream has already been stopped, finished, or reset.
    #[error("closed stream")]
    ClosedStream,
    /// Attempted an ordered read following an unordered read.
    ///
    /// Performing an unordered read allows discontinuities to arise in the
    /// receive buffer of a stream which cannot be recovered, making further
    /// ordered reads impossible.
    #[error("ordered read after unordered read")]
    IllegalOrderedRead,
    /// This was a 0-RTT stream and the server rejected it.
    ///
    /// Can only occur on clients for 0-RTT streams, which can be opened using
    /// [`Connecting::into_0rtt()`].
    ///
    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
458
459impl From<ReadableError> for ReadError {
460    fn from(e: ReadableError) -> Self {
461        match e {
462            ReadableError::ClosedStream => Self::ClosedStream,
463            ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead,
464        }
465    }
466}
467
468impl From<ResetError> for ReadError {
469    fn from(e: ResetError) -> Self {
470        match e {
471            ResetError::ConnectionLost(e) => Self::ConnectionLost(e),
472            ResetError::ZeroRttRejected => Self::ZeroRttRejected,
473        }
474    }
475}
476
477impl From<ReadError> for io::Error {
478    fn from(x: ReadError) -> Self {
479        use self::ReadError::*;
480        let kind = match x {
481            Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
482            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
483            IllegalOrderedRead => io::ErrorKind::InvalidInput,
484        };
485        Self::new(kind, x)
486    }
487}
488
/// Errors that arise from reading an exact number of bytes from a stream.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadExactError {
    /// The stream finished before all bytes were read
    ///
    /// Carries the number of bytes that were still expected.
    #[error("stream finished early (expected {0} bytes more)")]
    FinishedEarly(usize),
    /// A read error occurred
    #[error(transparent)]
    ReadError(#[from] ReadError),
}
499
/// Errors that arise while waiting for a stream to be reset
///
/// Returned by `RecvStream::received_reset`. Converts into both
/// [`ReadError`] and [`io::Error`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ResetError {
    /// The connection was lost
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// This was a 0-RTT stream and the server rejected it
    ///
    /// Can only occur on clients for 0-RTT streams, which can be opened using
    /// [`Connecting::into_0rtt()`].
    ///
    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
515
516impl From<ResetError> for io::Error {
517    fn from(x: ResetError) -> Self {
518        use ResetError::*;
519        let kind = match x {
520            ZeroRttRejected => io::ErrorKind::ConnectionReset,
521            ConnectionLost(_) => io::ErrorKind::NotConnected,
522        };
523        Self::new(kind, x)
524    }
525}
526
527impl AsyncRead for RecvStream {
528    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
529        let res = poll_fn(|cx| self.poll_read_uninit(cx, buf.as_uninit()))
530            .await
531            .inspect(|&n| unsafe { buf.advance_to(n) })
532            .map_err(Into::into);
533        BufResult(res, buf)
534    }
535}
536
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
        task::ready,
    };

    use compio_buf::{IntoInner, bytes::BufMut};

    use super::*;

    /// A [`futures_util`] compatible receive stream.
    pub struct CompatRecvStream(pub(super) RecvStream);

    impl CompatRecvStream {
        /// Polls a read into the next writable chunk of `buf`. On success,
        /// `buf` is advanced past the bytes written.
        ///
        /// Yields `Ok(None)` once the stream has been finished.
        fn poll_read(
            &mut self,
            cx: &mut Context,
            mut buf: impl BufMut,
        ) -> Poll<Result<Option<usize>, ReadError>> {
            // SAFETY: `poll_read_impl` only ever writes initialized bytes into
            // this slice, so no uninitialized data is written to the chunk.
            let dst = unsafe { buf.chunk_mut().as_uninit_slice_mut() };
            let polled = self.poll_read_impl(cx, dst);
            if let Poll::Ready(Ok(Some(written))) = &polled {
                // SAFETY: the first `written` bytes of the chunk were just
                // initialized by `poll_read_impl`.
                unsafe { buf.advance_mut(*written) }
            }
            polled
        }

        /// Read data contiguously from the stream.
        ///
        /// Yields the number of bytes read into `buf` on success, or `None` if
        /// the stream was finished.
        ///
        /// This operation is cancel-safe.
        pub async fn read(&mut self, mut buf: impl BufMut) -> Result<Option<usize>, ReadError> {
            let outcome = poll_fn(|cx| self.poll_read(cx, &mut buf)).await;
            outcome
        }

        /// Read an exact number of bytes contiguously from the stream.
        ///
        /// See [`read()`] for details. This operation is *not* cancel-safe.
        ///
        /// [`read()`]: CompatRecvStream::read
        pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> {
            poll_fn(|cx| loop {
                if !buf.has_remaining_mut() {
                    break Poll::Ready(Ok(()));
                }
                match ready!(self.poll_read(cx, &mut buf)) {
                    Ok(Some(_)) => {}
                    Ok(None) => {
                        break Poll::Ready(Err(ReadExactError::FinishedEarly(
                            buf.remaining_mut(),
                        )));
                    }
                    Err(e) => break Poll::Ready(Err(ReadExactError::from(e))),
                }
            })
            .await
        }
    }

    impl IntoInner for CompatRecvStream {
        type Inner = RecvStream;

        fn into_inner(self) -> Self::Inner {
            let Self(inner) = self;
            inner
        }
    }

    impl Deref for CompatRecvStream {
        type Target = RecvStream;

        fn deref(&self) -> &Self::Target {
            let Self(inner) = self;
            inner
        }
    }

    impl DerefMut for CompatRecvStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            let Self(inner) = self;
            inner
        }
    }

    impl futures_util::AsyncRead for CompatRecvStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let len = buf.len();
            // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and the
            // reader only writes initialized bytes, so `buf` stays initialized.
            let dst = unsafe {
                std::slice::from_raw_parts_mut(
                    buf.as_mut_ptr().cast::<std::mem::MaybeUninit<u8>>(),
                    len,
                )
            };
            let this = self.get_mut();
            match this.poll_read_uninit(cx, dst) {
                Poll::Ready(res) => Poll::Ready(res.map_err(io::Error::from)),
                Poll::Pending => Poll::Pending,
            }
        }
    }
}
634
635#[cfg(feature = "io-compat")]
636pub use compat::CompatRecvStream;
637
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use h3::quic::{self, StreamErrorIncoming};

    use super::*;

    impl From<ReadError> for StreamErrorIncoming {
        /// Maps a stream read failure onto the error type expected by `h3`.
        ///
        /// Variants without a dedicated `h3` counterpart are reported as
        /// `Unknown`. That includes `IllegalOrderedRead`, which only arises
        /// when an unordered read preceded an ordered one. This conversion is
        /// public and may be applied to any `ReadError`, so it must not panic.
        fn from(e: ReadError) -> Self {
            use ReadError::*;
            match e {
                Reset(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ConnectionLost(e) => Self::ConnectionErrorIncoming {
                    connection_error: e.into(),
                },
                e @ (ClosedStream | IllegalOrderedRead | ZeroRttRejected) => {
                    Self::Unknown(Box::new(e))
                }
            }
        }
    }

    impl quic::RecvStream for RecvStream {
        type Buf = Bytes;

        /// Polls the next ordered chunk of stream data; `None` once finished.
        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            self.execute_poll_read(cx, true, |chunks| match chunks.next(usize::MAX) {
                Ok(Some(chunk)) => ReadStatus::Readable(chunk.bytes),
                res => (None, res.err()).into(),
            })
            .map_err(Into::into)
        }

        /// Stops the stream with `error_code`. Errors from an already-closed
        /// stream are ignored.
        ///
        /// # Panics
        ///
        /// Panics if `error_code` does not fit in a QUIC variable-length integer.
        fn stop_sending(&mut self, error_code: u64) {
            self.stop(error_code.try_into().expect("invalid error_code"))
                .ok();
        }

        fn recv_id(&self) -> quic::StreamId {
            u64::from(self.stream).try_into().unwrap()
        }
    }
}