compio_quic/
recv_stream.rs

1use std::{
2    io,
3    mem::MaybeUninit,
4    task::{Context, Poll},
5};
6
7use compio_buf::{BufResult, IoBufMut, bytes::Bytes};
8use compio_io::AsyncRead;
9use futures_util::future::poll_fn;
10use quinn_proto::{Chunk, Chunks, ClosedStream, ReadableError, StreamId, VarInt};
11use thiserror::Error;
12
13use crate::{ConnectionError, ConnectionInner, StoppedError, sync::shared::Shared};
14
15/// A stream that can only be used to receive data
16///
17/// `stop(0)` is implicitly called on drop unless:
18/// - A variant of [`ReadError`] has been yielded by a read call
19/// - [`stop()`] was called explicitly
20///
21/// # Cancellation
22///
23/// A `read` method is said to be *cancel-safe* when dropping its future before
24/// the future becomes ready cannot lead to loss of stream data. This is true of
25/// methods which succeed immediately when any progress is made, and is not true
26/// of methods which might need to perform multiple reads internally before
27/// succeeding. Each `read` method documents whether it is cancel-safe.
28///
29/// # Common issues
30///
31/// ## Data never received on a locally-opened stream
32///
33/// Peers are not notified of streams until they or a later-numbered stream are
34/// used to send data. If a bidirectional stream is locally opened but never
35/// used to send, then the peer may never see it. Application protocols should
36/// always arrange for the endpoint which will first transmit on a stream to be
37/// the endpoint responsible for opening it.
38///
39/// ## Data never received on a remotely-opened stream
40///
41/// Verify that the stream you are receiving is the same one that the server is
42/// sending on, e.g. by logging the [`id`] of each. Streams are always accepted
43/// in the same order as they are created, i.e. ascending order by [`StreamId`].
44/// For example, even if a sender first transmits on bidirectional stream 1, the
45/// first stream yielded by [`Connection::accept_bi`] on the receiver
46/// will be bidirectional stream 0.
47///
48/// [`stop()`]: RecvStream::stop
49/// [`id`]: RecvStream::id
50/// [`Connection::accept_bi`]: crate::Connection::accept_bi
51#[derive(Debug)]
52pub struct RecvStream {
53    conn: Shared<ConnectionInner>,
54    stream: StreamId,
55    is_0rtt: bool,
56    all_data_read: bool,
57    reset: Option<VarInt>,
58}
59
60impl RecvStream {
61    pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
62        Self {
63            conn,
64            stream,
65            is_0rtt,
66            all_data_read: false,
67            reset: None,
68        }
69    }
70
71    /// Get the identity of this stream
72    pub fn id(&self) -> StreamId {
73        self.stream
74    }
75
76    /// Check if this stream has been opened during 0-RTT.
77    ///
78    /// In which case any non-idempotent request should be considered dangerous
79    /// at the application level. Because read data is subject to replay
80    /// attacks.
81    pub fn is_0rtt(&self) -> bool {
82        self.is_0rtt
83    }
84
85    /// Stop accepting data
86    ///
87    /// Discards unread data and notifies the peer to stop transmitting. Once
88    /// stopped, further attempts to operate on a stream will yield
89    /// `ClosedStream` errors.
90    pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
91        let mut state = self.conn.state();
92        if self.is_0rtt && !state.check_0rtt() {
93            return Ok(());
94        }
95        state.conn.recv_stream(self.stream).stop(error_code)?;
96        state.wake();
97        self.all_data_read = true;
98        Ok(())
99    }
100
101    /// Completes when the stream has been reset by the peer or otherwise
102    /// closed.
103    ///
104    /// Yields `Some` with the reset error code when the stream is reset by the
105    /// peer. Yields `None` when the stream was previously
106    /// [`stop()`](Self::stop)ed, or when the stream was
107    /// [`finish()`](crate::SendStream::finish)ed by the peer and all data has
108    /// been received, after which it is no longer meaningful for the stream to
109    /// be reset.
110    ///
111    /// This operation is cancel-safe.
112    pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
113        poll_fn(|cx| {
114            let mut state = self.conn.state();
115
116            if self.is_0rtt && !state.check_0rtt() {
117                return Poll::Ready(Err(StoppedError::ZeroRttRejected));
118            }
119            if let Some(code) = self.reset {
120                return Poll::Ready(Ok(Some(code)));
121            }
122
123            match state.conn.recv_stream(self.stream).received_reset() {
124                Err(_) => Poll::Ready(Ok(None)),
125                Ok(Some(error_code)) => {
126                    // Stream state has just now been freed, so the connection may need to issue new
127                    // stream ID flow control credit
128                    state.wake();
129                    Poll::Ready(Ok(Some(error_code)))
130                }
131                Ok(None) => {
132                    if let Some(e) = &state.error {
133                        return Poll::Ready(Err(e.clone().into()));
134                    }
135                    // Resets always notify readers, since a reset is an immediate read error. We
136                    // could introduce a dedicated channel to reduce the risk of spurious wakeups,
137                    // but that increased complexity is probably not justified, as an application
138                    // that is expecting a reset is not likely to receive large amounts of data.
139                    state.readable.insert(self.stream, cx.waker().clone());
140                    Poll::Pending
141                }
142            }
143        })
144        .await
145    }
146
147    /// Handle common logic related to reading out of a receive stream.
148    ///
149    /// This takes an `FnMut` closure that takes care of the actual reading
150    /// process, matching the detailed read semantics for the calling
151    /// function with a particular return type. The closure can read from
152    /// the passed `&mut Chunks` and has to return the status after reading:
153    /// the amount of data read, and the status after the final read call.
154    fn execute_poll_read<F, T>(
155        &mut self,
156        cx: &mut Context,
157        ordered: bool,
158        mut read_fn: F,
159    ) -> Poll<Result<Option<T>, ReadError>>
160    where
161        F: FnMut(&mut Chunks) -> ReadStatus<T>,
162    {
163        use quinn_proto::ReadError::*;
164
165        if self.all_data_read {
166            return Poll::Ready(Ok(None));
167        }
168
169        let mut state = self.conn.state();
170        if self.is_0rtt && !state.check_0rtt() {
171            return Poll::Ready(Err(ReadError::ZeroRttRejected));
172        }
173
174        // If we stored an error during a previous call, return it now. This can happen
175        // if a `read_fn` both wants to return data and also returns an error in
176        // its final stream status.
177        let status = match self.reset {
178            Some(code) => ReadStatus::Failed(None, Reset(code)),
179            None => {
180                let mut recv = state.conn.recv_stream(self.stream);
181                let mut chunks = recv.read(ordered)?;
182                let status = read_fn(&mut chunks);
183                if chunks.finalize().should_transmit() {
184                    state.wake();
185                }
186                status
187            }
188        };
189
190        match status {
191            ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))),
192            ReadStatus::Finished(read) => {
193                self.all_data_read = true;
194                Poll::Ready(Ok(read))
195            }
196            ReadStatus::Failed(read, Blocked) => match read {
197                Some(val) => Poll::Ready(Ok(Some(val))),
198                None => {
199                    if let Some(error) = &state.error {
200                        return Poll::Ready(Err(error.clone().into()));
201                    }
202                    state.readable.insert(self.stream, cx.waker().clone());
203                    Poll::Pending
204                }
205            },
206            ReadStatus::Failed(read, Reset(error_code)) => match read {
207                None => {
208                    self.all_data_read = true;
209                    self.reset = Some(error_code);
210                    Poll::Ready(Err(ReadError::Reset(error_code)))
211                }
212                done => {
213                    self.reset = Some(error_code);
214                    Poll::Ready(Ok(done))
215                }
216            },
217        }
218    }
219
220    pub(crate) fn poll_read_impl(
221        &mut self,
222        cx: &mut Context,
223        buf: &mut [MaybeUninit<u8>],
224    ) -> Poll<Result<Option<usize>, ReadError>> {
225        if buf.is_empty() {
226            return Poll::Ready(Ok(Some(0)));
227        }
228
229        self.execute_poll_read(cx, true, |chunks| {
230            let mut read = 0;
231            loop {
232                if read >= buf.len() {
233                    // We know `read > 0` because `buf` cannot be empty here
234                    return ReadStatus::Readable(read);
235                }
236
237                match chunks.next(buf.len() - read) {
238                    Ok(Some(chunk)) => {
239                        let bytes = chunk.bytes;
240                        let len = bytes.len();
241                        buf[read..read + len].copy_from_slice(unsafe {
242                            std::slice::from_raw_parts(bytes.as_ptr().cast(), len)
243                        });
244                        read += len;
245                    }
246                    res => {
247                        return (if read == 0 { None } else { Some(read) }, res.err()).into();
248                    }
249                }
250            }
251        })
252    }
253
254    /// Attempts to read from the stream into the provided buffer
255    ///
256    /// On success, returns `Poll::Ready(Ok(num_bytes_read))` and places data
257    /// into `buf`. If the buffer passed in has non-zero length and a 0 is
258    /// returned, that indicates that the remote side has [`finish`]ed the
259    /// stream and the local side has already read all bytes.
260    ///
261    /// If no data is available for reading, this returns `Poll::Pending` and
262    /// arranges for the current task (via `cx.waker()`) to be notified when
263    /// the stream becomes readable or is closed.
264    ///
265    /// [`finish`]: crate::SendStream::finish
266    pub fn poll_read_uninit(
267        &mut self,
268        cx: &mut Context,
269        buf: &mut [MaybeUninit<u8>],
270    ) -> Poll<Result<usize, ReadError>> {
271        self.poll_read_impl(cx, buf)
272            .map(|res| res.map(|n| n.unwrap_or_default()))
273    }
274
275    /// Read the next segment of data.
276    ///
277    /// Yields `None` if the stream was finished. Otherwise, yields a segment of
278    /// data and its offset in the stream. If `ordered` is `true`, the chunk's
279    /// offset will be immediately after the last data yielded by
280    /// [`read()`](Self::read) or [`read_chunk()`](Self::read_chunk). If
281    /// `ordered` is `false`, segments may be received in any order, and the
282    /// `Chunk`'s `offset` field can be used to determine ordering in the
283    /// caller. Unordered reads are less prone to head-of-line blocking within a
284    /// stream, but require the application to manage reassembling the original
285    /// data.
286    ///
287    /// Slightly more efficient than `read` due to not copying. Chunk boundaries
288    /// do not correspond to peer writes, and hence cannot be used as framing.
289    ///
290    /// This operation is cancel-safe.
291    pub async fn read_chunk(
292        &mut self,
293        max_length: usize,
294        ordered: bool,
295    ) -> Result<Option<Chunk>, ReadError> {
296        poll_fn(|cx| {
297            self.execute_poll_read(cx, ordered, |chunks| match chunks.next(max_length) {
298                Ok(Some(chunk)) => ReadStatus::Readable(chunk),
299                res => (None, res.err()).into(),
300            })
301        })
302        .await
303    }
304
305    /// Read the next segments of data.
306    ///
307    /// Fills `bufs` with the segments of data beginning immediately after the
308    /// last data yielded by `read` or `read_chunk`, or `None` if the stream was
309    /// finished.
310    ///
311    /// Slightly more efficient than `read` due to not copying. Chunk boundaries
312    /// do not correspond to peer writes, and hence cannot be used as framing.
313    ///
314    /// This operation is cancel-safe.
315    pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Option<usize>, ReadError> {
316        if bufs.is_empty() {
317            return Ok(Some(0));
318        }
319
320        poll_fn(|cx| {
321            self.execute_poll_read(cx, true, |chunks| {
322                let mut read = 0;
323                loop {
324                    if read >= bufs.len() {
325                        // We know `read > 0` because `bufs` cannot be empty here
326                        return ReadStatus::Readable(read);
327                    }
328
329                    match chunks.next(usize::MAX) {
330                        Ok(Some(chunk)) => {
331                            bufs[read] = chunk.bytes;
332                            read += 1;
333                        }
334                        res => {
335                            return (if read == 0 { None } else { Some(read) }, res.err()).into();
336                        }
337                    }
338                }
339            })
340        })
341        .await
342    }
343
344    /// Convenience method to read all remaining data into a buffer.
345    ///
346    /// If unordered reads have already been made, the resulting buffer may have
347    /// gaps containing zeros.
348    ///
349    /// This operation is *not* cancel-safe.
350    pub async fn read_to_end<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
351        let mut start = u64::MAX;
352        let mut end = 0;
353        let mut chunks = vec![];
354        loop {
355            let chunk = match self.read_chunk(usize::MAX, false).await {
356                Ok(Some(chunk)) => chunk,
357                Ok(None) => break,
358                Err(e) => return BufResult(Err(e.into()), buf),
359            };
360            start = start.min(chunk.offset);
361            end = end.max(chunk.offset + chunk.bytes.len() as u64);
362            chunks.push((chunk.offset, chunk.bytes));
363        }
364        if start == u64::MAX || start >= end {
365            // no data read
366            return BufResult(Ok(0), buf);
367        }
368        let len = (end - start) as usize;
369        let cap = buf.buf_capacity();
370        let needed = len.saturating_sub(cap);
371        if needed > 0
372            && let Err(e) = buf.reserve(needed)
373        {
374            return BufResult(Err(io::Error::new(io::ErrorKind::OutOfMemory, e)), buf);
375        }
376        let slice = &mut buf.as_uninit()[..len];
377        slice.fill(MaybeUninit::new(0));
378        for (offset, bytes) in chunks {
379            let offset = (offset - start) as usize;
380            let buf_len = bytes.len();
381            slice[offset..offset + buf_len].copy_from_slice(unsafe {
382                std::slice::from_raw_parts(bytes.as_ptr().cast(), buf_len)
383            });
384        }
385        unsafe { buf.advance_to(len) }
386        BufResult(Ok(len), buf)
387    }
388
389    /// Convert into an [`futures_util`] compatible stream.
390    #[cfg(feature = "io-compat")]
391    pub fn into_compat(self) -> CompatRecvStream {
392        CompatRecvStream(self)
393    }
394}
395
396impl Drop for RecvStream {
397    fn drop(&mut self) {
398        let mut state = self.conn.state();
399
400        // clean up any previously registered wakers
401        state.readable.remove(&self.stream);
402
403        if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
404            return;
405        }
406        if !self.all_data_read {
407            // Ignore ClosedStream errors
408            let _ = state.conn.recv_stream(self.stream).stop(0u32.into());
409            state.wake();
410        }
411    }
412}
413
414enum ReadStatus<T> {
415    Readable(T),
416    Finished(Option<T>),
417    Failed(Option<T>, quinn_proto::ReadError),
418}
419
420impl<T> From<(Option<T>, Option<quinn_proto::ReadError>)> for ReadStatus<T> {
421    fn from(status: (Option<T>, Option<quinn_proto::ReadError>)) -> Self {
422        match status {
423            (read, None) => Self::Finished(read),
424            (read, Some(e)) => Self::Failed(read, e),
425        }
426    }
427}
428
429/// Errors that arise from reading from a stream.
430#[derive(Debug, Error, Clone, PartialEq, Eq)]
431pub enum ReadError {
432    /// The peer abandoned transmitting data on this stream.
433    ///
434    /// Carries an application-defined error code.
435    #[error("stream reset by peer: error {0}")]
436    Reset(VarInt),
437    /// The connection was lost.
438    #[error("connection lost")]
439    ConnectionLost(#[from] ConnectionError),
440    /// The stream has already been stopped, finished, or reset.
441    #[error("closed stream")]
442    ClosedStream,
443    /// Attempted an ordered read following an unordered read.
444    ///
445    /// Performing an unordered read allows discontinuities to arise in the
446    /// receive buffer of a stream which cannot be recovered, making further
447    /// ordered reads impossible.
448    #[error("ordered read after unordered read")]
449    IllegalOrderedRead,
450    /// This was a 0-RTT stream and the server rejected it.
451    ///
452    /// Can only occur on clients for 0-RTT streams, which can be opened using
453    /// [`Connecting::into_0rtt()`].
454    ///
455    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
456    #[error("0-RTT rejected")]
457    ZeroRttRejected,
458}
459
460impl From<ReadableError> for ReadError {
461    fn from(e: ReadableError) -> Self {
462        match e {
463            ReadableError::ClosedStream => Self::ClosedStream,
464            ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead,
465        }
466    }
467}
468
469impl From<StoppedError> for ReadError {
470    fn from(e: StoppedError) -> Self {
471        match e {
472            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
473            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
474        }
475    }
476}
477
478impl From<ReadError> for io::Error {
479    fn from(x: ReadError) -> Self {
480        use self::ReadError::*;
481        let kind = match x {
482            Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
483            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
484            IllegalOrderedRead => io::ErrorKind::InvalidInput,
485        };
486        Self::new(kind, x)
487    }
488}
489
490/// Errors that arise from reading from a stream.
491#[derive(Debug, Error, Clone, PartialEq, Eq)]
492pub enum ReadExactError {
493    /// The stream finished before all bytes were read
494    #[error("stream finished early (expected {0} bytes more)")]
495    FinishedEarly(usize),
496    /// A read error occurred
497    #[error(transparent)]
498    ReadError(#[from] ReadError),
499}
500
501impl AsyncRead for RecvStream {
502    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
503        let res = poll_fn(|cx| self.poll_read_uninit(cx, buf.as_uninit()))
504            .await
505            .inspect(|&n| unsafe { buf.advance_to(n) })
506            .map_err(Into::into);
507        BufResult(res, buf)
508    }
509}
510
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
        task::ready,
    };

    use compio_buf::{IntoInner, bytes::BufMut};

    use super::*;

    /// A receive stream adapter that is compatible with [`futures_util`].
    pub struct CompatRecvStream(pub(super) RecvStream);

    impl CompatRecvStream {
        /// Poll for data, write it into the next chunk of `buf`, and advance
        /// `buf` by the number of bytes written.
        fn poll_read(
            &mut self,
            cx: &mut Context,
            mut buf: impl BufMut,
        ) -> Poll<Result<Option<usize>, ReadError>> {
            // SAFETY: the slice is only handed to `poll_read_impl`, which copies
            // stream bytes into it.
            let spare = unsafe { buf.chunk_mut().as_uninit_slice_mut() };
            let polled = self.0.poll_read_impl(cx, spare);
            if let Poll::Ready(Ok(Some(n))) = &polled {
                // SAFETY: `poll_read_impl` reported writing `n` bytes to the front
                // of the chunk.
                unsafe { buf.advance_mut(*n) };
            }
            polled
        }

        /// Read data contiguously from the stream.
        ///
        /// On success yields the number of bytes read into `buf`, or `None`
        /// if the stream was finished.
        ///
        /// This operation is cancel-safe.
        pub async fn read(&mut self, mut buf: impl BufMut) -> Result<Option<usize>, ReadError> {
            let fut = poll_fn(|cx| self.poll_read(cx, &mut buf));
            fut.await
        }

        /// Read an exact number of bytes contiguously from the stream.
        ///
        /// Keeps reading until `buf` has no remaining capacity. See
        /// [`read()`] for details. This operation is *not* cancel-safe.
        ///
        /// [`read()`]: CompatRecvStream::read
        pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> {
            poll_fn(|cx| {
                loop {
                    if !buf.has_remaining_mut() {
                        return Poll::Ready(Ok(()));
                    }
                    match ready!(self.poll_read(cx, &mut buf)) {
                        Ok(Some(_)) => continue,
                        // The stream ended while `buf` still had room.
                        Ok(None) => {
                            let missing = buf.remaining_mut();
                            return Poll::Ready(Err(ReadExactError::FinishedEarly(missing)));
                        }
                        Err(err) => return Poll::Ready(Err(ReadExactError::ReadError(err))),
                    }
                }
            })
            .await
        }
    }

    impl IntoInner for CompatRecvStream {
        type Inner = RecvStream;

        fn into_inner(self) -> Self::Inner {
            let Self(inner) = self;
            inner
        }
    }

    impl Deref for CompatRecvStream {
        type Target = RecvStream;

        fn deref(&self) -> &Self::Target {
            let Self(inner) = self;
            inner
        }
    }

    impl DerefMut for CompatRecvStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            let Self(inner) = self;
            inner
        }
    }

    impl futures_util::AsyncRead for CompatRecvStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            // SAFETY: `buf` is valid for `buf.len()` bytes, `MaybeUninit<u8>` has the
            // same layout as `u8`, and the callee only copies stream bytes into the
            // slice.
            let uninit = unsafe {
                std::slice::from_raw_parts_mut(
                    buf.as_mut_ptr().cast::<MaybeUninit<u8>>(),
                    buf.len(),
                )
            };
            this.poll_read_uninit(cx, uninit).map_err(io::Error::from)
        }
    }
}
608
609#[cfg(feature = "io-compat")]
610pub use compat::CompatRecvStream;
611
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use h3::quic::{self, StreamErrorIncoming};

    use super::*;

    /// Translates read errors into the error type `h3` expects from a stream.
    impl From<ReadError> for StreamErrorIncoming {
        fn from(e: ReadError) -> Self {
            match e {
                ReadError::Reset(code) => {
                    let error_code = code.into_inner();
                    Self::StreamTerminated { error_code }
                }
                ReadError::ConnectionLost(lost) => {
                    let connection_error = lost.into();
                    Self::ConnectionErrorIncoming { connection_error }
                }
                ReadError::IllegalOrderedRead => unreachable!("illegal ordered read"),
                other @ (ReadError::ClosedStream | ReadError::ZeroRttRejected) => {
                    Self::Unknown(Box::new(other))
                }
            }
        }
    }

    impl quic::RecvStream for RecvStream {
        type Buf = Bytes;

        /// Poll for the next ordered chunk of data; `None` once the stream
        /// has finished.
        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            let polled = self.execute_poll_read(cx, true, |chunks| {
                match chunks.next(usize::MAX) {
                    Ok(Some(chunk)) => ReadStatus::Readable(chunk.bytes),
                    Ok(None) => ReadStatus::Finished(None),
                    Err(err) => ReadStatus::Failed(None, err),
                }
            });
            polled.map_err(StreamErrorIncoming::from)
        }

        /// Stop the stream with `error_code`, ignoring an already closed
        /// stream.
        ///
        /// # Panics
        ///
        /// Panics if `error_code` cannot be converted to a `VarInt`.
        fn stop_sending(&mut self, error_code: u64) {
            let code = VarInt::try_from(error_code).expect("invalid error_code");
            let _ = self.stop(code);
        }

        /// The stream's identifier as an `h3` stream id.
        fn recv_id(&self) -> quic::StreamId {
            let raw = u64::from(self.stream);
            quic::StreamId::try_from(raw).unwrap()
        }
    }
}