Skip to main content

compio_quic/
send_stream.rs

1use std::{
2    io,
3    task::{Context, Poll},
4};
5
6use compio_buf::{BufResult, IoBuf, bytes::Bytes};
7use compio_io::AsyncWrite;
8use futures_util::{future::poll_fn, ready};
9use quinn_proto::{ClosedStream, FinishError, StreamId, VarInt, Written};
10use thiserror::Error;
11
12use crate::{ConnectionError, ConnectionInner, sync::shared::Shared};
13
14/// A stream that can only be used to send data.
15///
16/// If dropped, streams that haven't been explicitly [`reset()`] will be
17/// implicitly [`finish()`]ed, continuing to (re)transmit previously written
18/// data until it has been fully acknowledged or the connection is closed.
19///
20/// # Cancellation
21///
22/// A `write` method is said to be *cancel-safe* when dropping its future before
23/// the future becomes ready will always result in no data being written to the
24/// stream. This is true of methods which succeed immediately when any progress
25/// is made, and is not true of methods which might need to perform multiple
26/// writes internally before succeeding. Each `write` method documents whether
27/// it is cancel-safe.
28///
29/// [`reset()`]: SendStream::reset
30/// [`finish()`]: SendStream::finish
31#[derive(Debug)]
32pub struct SendStream {
33    conn: Shared<ConnectionInner>,
34    stream: StreamId,
35    is_0rtt: bool,
36}
37
38impl SendStream {
39    pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
40        Self {
41            conn,
42            stream,
43            is_0rtt,
44        }
45    }
46
47    /// Get the identity of this stream
48    pub fn id(&self) -> StreamId {
49        self.stream
50    }
51
52    /// Notify the peer that no more data will ever be written to this stream.
53    ///
54    /// It is an error to write to a stream after `finish()`ing it. [`reset()`]
55    /// may still be called after `finish` to abandon transmission of any stream
56    /// data that might still be buffered.
57    ///
58    /// To wait for the peer to receive all buffered stream data, see
59    /// [`stopped()`].
60    ///
61    /// May fail if [`finish()`] or  [`reset()`] was previously called.This
62    /// error is harmless and serves only to indicate that the caller may have
63    /// incorrect assumptions about the stream's state.
64    ///
65    /// [`reset()`]: Self::reset
66    /// [`stopped()`]: Self::stopped
67    /// [`finish()`]: Self::finish
68    pub fn finish(&mut self) -> Result<(), ClosedStream> {
69        let mut state = self.conn.state();
70        match state.conn.send_stream(self.stream).finish() {
71            Ok(()) => {
72                state.wake();
73                Ok(())
74            }
75            Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
76            // Harmless. If the application needs to know about stopped streams at this point,
77            // it should call `stopped`.
78            Err(FinishError::Stopped(_)) => Ok(()),
79        }
80    }
81
82    /// Close the stream immediately.
83    ///
84    /// No new data can be written after calling this method. Locally buffered
85    /// data is dropped, and previously transmitted data will no longer be
86    /// retransmitted if lost. If an attempt has already been made to finish
87    /// the stream, the peer may still receive all written data.
88    ///
89    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was
90    /// previously called. This error is harmless and serves only to
91    /// indicate that the caller may have incorrect assumptions about the
92    /// stream's state.
93    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
94        let mut state = self.conn.state();
95        if self.is_0rtt && !state.check_0rtt() {
96            return Ok(());
97        }
98        state.conn.send_stream(self.stream).reset(error_code)?;
99        state.wake();
100        Ok(())
101    }
102
103    /// Set the priority of the stream.
104    ///
105    /// Every stream has an initial priority of 0. Locally buffered data
106    /// from streams with higher priority will be transmitted before data
107    /// from streams with lower priority. Changing the priority of a stream
108    /// with pending data may only take effect after that data has been
109    /// transmitted. Using many different priority levels per connection may
110    /// have a negative impact on performance.
111    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
112        self.conn
113            .state()
114            .conn
115            .send_stream(self.stream)
116            .set_priority(priority)
117    }
118
119    /// Get the priority of the stream
120    pub fn priority(&self) -> Result<i32, ClosedStream> {
121        self.conn.state().conn.send_stream(self.stream).priority()
122    }
123
124    /// Completes when the peer stops the stream or reads the stream to
125    /// completion.
126    ///
127    /// Yields `Some` with the stop error code if the peer stops the stream.
128    /// Yields `None` if the local side [`finish()`](Self::finish)es the stream
129    /// and then the peer acknowledges receipt of all stream data (although not
130    /// necessarily the processing of it), after which the peer closing the
131    /// stream is no longer meaningful.
132    ///
133    /// For a variety of reasons, the peer may not send acknowledgements
134    /// immediately upon receiving data. As such, relying on `stopped` to
135    /// know when the peer has read a stream to completion may introduce
136    /// more latency than using an application-level response of some sort.
137    pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
138        poll_fn(|cx| {
139            let mut state = self.conn.state();
140            if self.is_0rtt && !state.check_0rtt() {
141                return Poll::Ready(Err(StoppedError::ZeroRttRejected));
142            }
143            match state.conn.send_stream(self.stream).stopped() {
144                Err(_) => Poll::Ready(Ok(None)),
145                Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))),
146                Ok(None) => {
147                    if let Some(e) = &state.error {
148                        return Poll::Ready(Err(e.clone().into()));
149                    }
150                    state.stopped.insert(self.stream, cx.waker().clone());
151                    Poll::Pending
152                }
153            }
154        })
155        .await
156    }
157
158    fn execute_poll_write<F, R>(&mut self, cx: &mut Context, f: F) -> Poll<Result<R, WriteError>>
159    where
160        F: FnOnce(quinn_proto::SendStream) -> Result<R, quinn_proto::WriteError>,
161    {
162        let mut state = self.conn.try_state()?;
163        if self.is_0rtt && !state.check_0rtt() {
164            return Poll::Ready(Err(WriteError::ZeroRttRejected));
165        }
166        match f(state.conn.send_stream(self.stream)) {
167            Ok(r) => {
168                state.wake();
169                Poll::Ready(Ok(r))
170            }
171            Err(e) => match e.try_into() {
172                Ok(e) => Poll::Ready(Err(e)),
173                Err(()) => {
174                    state.writable.insert(self.stream, cx.waker().clone());
175                    Poll::Pending
176                }
177            },
178        }
179    }
180
181    /// Write chunks to the stream.
182    ///
183    /// Yields the number of bytes and chunks written on success.
184    /// Congestion and flow control may cause this to be shorter than
185    /// `buf.len()`, indicating that only a prefix of `bufs` was written.
186    ///
187    /// This operation is cancel-safe.
188    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
189        poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write_chunks(bufs))).await
190    }
191
192    /// Convenience method to write an entire list of chunks to the stream.
193    ///
194    /// This operation is *not* cancel-safe.
195    pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
196        let mut chunks = 0;
197        poll_fn(|cx| {
198            loop {
199                if chunks == bufs.len() {
200                    return Poll::Ready(Ok(()));
201                }
202                let written = ready!(self.execute_poll_write(cx, |mut stream| {
203                    stream.write_chunks(&mut bufs[chunks..])
204                }))?;
205                chunks += written.chunks;
206            }
207        })
208        .await
209    }
210
211    /// Convert this stream into a [`futures_util`] compatible stream.
212    #[cfg(feature = "io-compat")]
213    pub fn into_compat(self) -> CompatSendStream {
214        CompatSendStream(self)
215    }
216}
217
218impl Drop for SendStream {
219    fn drop(&mut self) {
220        let mut state = self.conn.state();
221
222        // clean up any previously registered wakers
223        state.stopped.remove(&self.stream);
224        state.writable.remove(&self.stream);
225
226        if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
227            return;
228        }
229        match state.conn.send_stream(self.stream).finish() {
230            Ok(()) => state.wake(),
231            Err(FinishError::Stopped(reason)) => {
232                if state.conn.send_stream(self.stream).reset(reason).is_ok() {
233                    state.wake();
234                }
235            }
236            // Already finished or reset, which is fine.
237            Err(FinishError::ClosedStream) => {}
238        }
239    }
240}
241
242/// Errors that arise from writing to a stream
243#[derive(Debug, Error, Clone, PartialEq, Eq)]
244pub enum WriteError {
245    /// The peer is no longer accepting data on this stream
246    ///
247    /// Carries an application-defined error code.
248    #[error("sending stopped by peer: error {0}")]
249    Stopped(VarInt),
250    /// The connection was lost
251    #[error("connection lost")]
252    ConnectionLost(#[from] ConnectionError),
253    /// The stream has already been finished or reset
254    #[error("closed stream")]
255    ClosedStream,
256    /// This was a 0-RTT stream and the server rejected it
257    ///
258    /// Can only occur on clients for 0-RTT streams, which can be opened using
259    /// [`Connecting::into_0rtt()`].
260    ///
261    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
262    #[error("0-RTT rejected")]
263    ZeroRttRejected,
264    /// Error when the stream is not ready, because it is still sending
265    /// data from a previous call
266    #[cfg(feature = "h3")]
267    #[error("stream not ready")]
268    NotReady,
269}
270
271impl TryFrom<quinn_proto::WriteError> for WriteError {
272    type Error = ();
273
274    fn try_from(value: quinn_proto::WriteError) -> Result<Self, Self::Error> {
275        use quinn_proto::WriteError::*;
276        match value {
277            Stopped(e) => Ok(Self::Stopped(e)),
278            ClosedStream => Ok(Self::ClosedStream),
279            Blocked => Err(()),
280        }
281    }
282}
283
284impl From<StoppedError> for WriteError {
285    fn from(x: StoppedError) -> Self {
286        match x {
287            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
288            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
289        }
290    }
291}
292
293impl From<WriteError> for io::Error {
294    fn from(x: WriteError) -> Self {
295        use WriteError::*;
296        let kind = match x {
297            Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
298            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
299            #[cfg(feature = "h3")]
300            NotReady => io::ErrorKind::Other,
301        };
302        Self::new(kind, x)
303    }
304}
305
306/// Errors that arise while monitoring for a send stream stop from the peer
307#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
308pub enum StoppedError {
309    /// The connection was lost
310    #[error("connection lost")]
311    ConnectionLost(#[from] ConnectionError),
312    /// This was a 0-RTT stream and the server rejected it
313    ///
314    /// Can only occur on clients for 0-RTT streams, which can be opened using
315    /// [`Connecting::into_0rtt()`].
316    ///
317    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
318    #[error("0-RTT rejected")]
319    ZeroRttRejected,
320}
321
322impl From<StoppedError> for io::Error {
323    fn from(x: StoppedError) -> Self {
324        use StoppedError::*;
325        let kind = match x {
326            ZeroRttRejected => io::ErrorKind::ConnectionReset,
327            ConnectionLost(_) => io::ErrorKind::NotConnected,
328        };
329        Self::new(kind, x)
330    }
331}
332
333impl AsyncWrite for SendStream {
334    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
335        let res =
336            poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf.as_init())))
337                .await
338                .map_err(Into::into);
339        BufResult(res, buf)
340    }
341
342    async fn flush(&mut self) -> io::Result<()> {
343        Ok(())
344    }
345
346    async fn shutdown(&mut self) -> io::Result<()> {
347        self.finish()?;
348        Ok(())
349    }
350}
351
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
    };

    use compio_buf::IntoInner;

    use super::*;

    /// A [`futures_util`] compatible send stream.
    ///
    /// Wraps a [`SendStream`] and implements [`futures_util::AsyncWrite`] for
    /// it. All [`SendStream`] methods remain reachable through `Deref` and
    /// `DerefMut`, and [`IntoInner`] recovers the wrapped stream.
    #[derive(Debug)]
    pub struct CompatSendStream(pub(super) SendStream);

    impl CompatSendStream {
        /// Write bytes to the stream.
        ///
        /// Yields the number of bytes written on success. Congestion and flow
        /// control may cause this to be shorter than `buf.len()`, indicating
        /// that only a prefix of `buf` was written.
        ///
        /// This operation is cancel-safe.
        pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
            poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf))).await
        }

        /// Convenience method to write an entire buffer to the stream.
        ///
        /// Repeatedly writes the unwritten suffix of `buf` until all of it
        /// has been accepted. An empty `buf` completes immediately.
        ///
        /// This operation is *not* cancel-safe.
        pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
            // Bytes of `buf` written so far; persists across polls.
            let mut count = 0;
            poll_fn(|cx| {
                loop {
                    if count == buf.len() {
                        return Poll::Ready(Ok(()));
                    }
                    let n = ready!(
                        self.execute_poll_write(cx, |mut stream| stream.write(&buf[count..]))
                    )?;
                    count += n;
                }
            })
            .await
        }
    }

    impl IntoInner for CompatSendStream {
        type Inner = SendStream;

        /// Unwrap the underlying [`SendStream`].
        fn into_inner(self) -> Self::Inner {
            self.0
        }
    }

    impl Deref for CompatSendStream {
        type Target = SendStream;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl DerefMut for CompatSendStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.0
        }
    }

    impl futures_util::AsyncWrite for CompatSendStream {
        /// Write a prefix of `buf`, parking the task while the stream is
        /// blocked; errors are converted to [`io::Error`].
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.get_mut()
                .execute_poll_write(cx, |mut stream| stream.write(buf))
                .map_err(Into::into)
        }

        /// No-op: this type keeps no write buffer of its own.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Finish the stream; see [`SendStream::finish`].
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.get_mut().finish()?;
            Poll::Ready(Ok(()))
        }
    }
}
441
442#[cfg(feature = "io-compat")]
443pub use compat::CompatSendStream;
444
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use compio_buf::bytes::Buf;
    use h3::quic::{self, StreamErrorIncoming, WriteBuf};

    use super::*;

    impl From<WriteError> for StreamErrorIncoming {
        /// Translate a [`WriteError`] into the error type `h3` expects.
        ///
        /// Peer stops and connection loss map to their dedicated variants;
        /// everything else is wrapped as [`StreamErrorIncoming::Unknown`].
        fn from(e: WriteError) -> Self {
            use WriteError::*;
            match e {
                Stopped(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ConnectionLost(e) => Self::ConnectionErrorIncoming {
                    connection_error: e.into(),
                },

                e => Self::Unknown(Box::new(e)),
            }
        }
    }

    /// A wrapper around `SendStream` that implements `quic::SendStream` and
    /// `quic::SendStreamUnframed`.
    pub struct SendStream<B> {
        inner: super::SendStream,
        /// Data queued by `send_data` that `poll_ready` has not yet fully
        /// written. `Some` means the stream is not ready for more data.
        buf: Option<WriteBuf<B>>,
    }

    impl<B> SendStream<B> {
        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                inner: super::SendStream::new(conn, stream, is_0rtt),
                buf: None,
            }
        }
    }

    impl<B> quic::SendStream<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Drive any data queued by `send_data` to completion.
        ///
        /// Resolves `Ok(())` and clears the queue once it is fully written (or
        /// if nothing was queued). On error the queued data is left in place.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            if let Some(data) = &mut self.buf {
                while data.has_remaining() {
                    let n = ready!(
                        self.inner
                            .execute_poll_write(cx, |mut stream| stream.write(data.chunk()))
                    )?;
                    data.advance(n);
                }
            }
            self.buf = None;
            Poll::Ready(Ok(()))
        }

        /// Queue `data` to be written by the next `poll_ready`.
        ///
        /// Fails with [`WriteError::NotReady`] if earlier data is still queued.
        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            if self.buf.is_some() {
                return Err(WriteError::NotReady.into());
            }
            self.buf = Some(data.into());
            Ok(())
        }

        /// Finish the stream; any failure is reported as
        /// [`WriteError::ClosedStream`].
        fn poll_finish(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            Poll::Ready(
                self.inner
                    .finish()
                    .map_err(|_| WriteError::ClosedStream.into()),
            )
        }

        /// Reset the stream with `reset_code`. Codes that do not fit in a
        /// [`VarInt`] are replaced by `VarInt::MAX`, and errors are ignored.
        fn reset(&mut self, reset_code: u64) {
            self.inner
                .reset(reset_code.try_into().unwrap_or(VarInt::MAX))
                .ok();
        }

        fn send_id(&self) -> quic::StreamId {
            // NOTE(review): presumed infallible because quinn stream ids are
            // varint-bounded; confirm against `h3::quic::StreamId`'s `TryFrom`.
            u64::from(self.inner.stream).try_into().unwrap()
        }
    }

    impl<B> quic::SendStreamUnframed<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Write a prefix of `buf` directly to the stream and advance it by the
        /// number of bytes written.
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            // The stream is "not ready" while `self.buf` still holds data from
            // `send_data`. Writing here would put those bytes out of order, so
            // reaching this with `self.buf` set signifies a bug in the caller.
            debug_assert!(
                self.buf.is_none(),
                "poll_send called while send stream is not ready"
            );

            let n = ready!(
                self.inner
                    .execute_poll_write(cx, |mut stream| stream.write(buf.chunk()))
            )?;
            buf.advance(n);
            Poll::Ready(Ok(n))
        }
    }
}