compio_quic/
send_stream.rs

1use std::{
2    io,
3    task::{Context, Poll},
4};
5
6use compio_buf::{BufResult, IoBuf, bytes::Bytes};
7use compio_io::AsyncWrite;
8use futures_util::{future::poll_fn, ready};
9use quinn_proto::{ClosedStream, FinishError, StreamId, VarInt, Written};
10use thiserror::Error;
11
12use crate::{ConnectionError, ConnectionInner, StoppedError, sync::shared::Shared};
13
14/// A stream that can only be used to send data.
15///
16/// If dropped, streams that haven't been explicitly [`reset()`] will be
17/// implicitly [`finish()`]ed, continuing to (re)transmit previously written
18/// data until it has been fully acknowledged or the connection is closed.
19///
20/// # Cancellation
21///
22/// A `write` method is said to be *cancel-safe* when dropping its future before
23/// the future becomes ready will always result in no data being written to the
24/// stream. This is true of methods which succeed immediately when any progress
25/// is made, and is not true of methods which might need to perform multiple
26/// writes internally before succeeding. Each `write` method documents whether
27/// it is cancel-safe.
28///
29/// [`reset()`]: SendStream::reset
30/// [`finish()`]: SendStream::finish
31#[derive(Debug)]
32pub struct SendStream {
33    conn: Shared<ConnectionInner>,
34    stream: StreamId,
35    is_0rtt: bool,
36}
37
38impl SendStream {
39    pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
40        Self {
41            conn,
42            stream,
43            is_0rtt,
44        }
45    }
46
47    /// Get the identity of this stream
48    pub fn id(&self) -> StreamId {
49        self.stream
50    }
51
52    /// Notify the peer that no more data will ever be written to this stream.
53    ///
54    /// It is an error to write to a stream after `finish()`ing it. [`reset()`]
55    /// may still be called after `finish` to abandon transmission of any stream
56    /// data that might still be buffered.
57    ///
58    /// To wait for the peer to receive all buffered stream data, see
59    /// [`stopped()`].
60    ///
61    /// May fail if [`finish()`] or  [`reset()`] was previously called.This
62    /// error is harmless and serves only to indicate that the caller may have
63    /// incorrect assumptions about the stream's state.
64    ///
65    /// [`reset()`]: Self::reset
66    /// [`stopped()`]: Self::stopped
67    /// [`finish()`]: Self::finish
68    pub fn finish(&mut self) -> Result<(), ClosedStream> {
69        let mut state = self.conn.state();
70        match state.conn.send_stream(self.stream).finish() {
71            Ok(()) => {
72                state.wake();
73                Ok(())
74            }
75            Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
76            // Harmless. If the application needs to know about stopped streams at this point,
77            // it should call `stopped`.
78            Err(FinishError::Stopped(_)) => Ok(()),
79        }
80    }
81
82    /// Close the stream immediately.
83    ///
84    /// No new data can be written after calling this method. Locally buffered
85    /// data is dropped, and previously transmitted data will no longer be
86    /// retransmitted if lost. If an attempt has already been made to finish
87    /// the stream, the peer may still receive all written data.
88    ///
89    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was
90    /// previously called. This error is harmless and serves only to
91    /// indicate that the caller may have incorrect assumptions about the
92    /// stream's state.
93    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
94        let mut state = self.conn.state();
95        if self.is_0rtt && !state.check_0rtt() {
96            return Ok(());
97        }
98        state.conn.send_stream(self.stream).reset(error_code)?;
99        state.wake();
100        Ok(())
101    }
102
103    /// Set the priority of the stream.
104    ///
105    /// Every stream has an initial priority of 0. Locally buffered data
106    /// from streams with higher priority will be transmitted before data
107    /// from streams with lower priority. Changing the priority of a stream
108    /// with pending data may only take effect after that data has been
109    /// transmitted. Using many different priority levels per connection may
110    /// have a negative impact on performance.
111    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
112        self.conn
113            .state()
114            .conn
115            .send_stream(self.stream)
116            .set_priority(priority)
117    }
118
119    /// Get the priority of the stream
120    pub fn priority(&self) -> Result<i32, ClosedStream> {
121        self.conn.state().conn.send_stream(self.stream).priority()
122    }
123
124    /// Completes when the peer stops the stream or reads the stream to
125    /// completion.
126    ///
127    /// Yields `Some` with the stop error code if the peer stops the stream.
128    /// Yields `None` if the local side [`finish()`](Self::finish)es the stream
129    /// and then the peer acknowledges receipt of all stream data (although not
130    /// necessarily the processing of it), after which the peer closing the
131    /// stream is no longer meaningful.
132    ///
133    /// For a variety of reasons, the peer may not send acknowledgements
134    /// immediately upon receiving data. As such, relying on `stopped` to
135    /// know when the peer has read a stream to completion may introduce
136    /// more latency than using an application-level response of some sort.
137    pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
138        poll_fn(|cx| {
139            let mut state = self.conn.state();
140            if self.is_0rtt && !state.check_0rtt() {
141                return Poll::Ready(Err(StoppedError::ZeroRttRejected));
142            }
143            match state.conn.send_stream(self.stream).stopped() {
144                Err(_) => Poll::Ready(Ok(None)),
145                Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))),
146                Ok(None) => {
147                    if let Some(e) = &state.error {
148                        return Poll::Ready(Err(e.clone().into()));
149                    }
150                    state.stopped.insert(self.stream, cx.waker().clone());
151                    Poll::Pending
152                }
153            }
154        })
155        .await
156    }
157
158    fn execute_poll_write<F, R>(&mut self, cx: &mut Context, f: F) -> Poll<Result<R, WriteError>>
159    where
160        F: FnOnce(quinn_proto::SendStream) -> Result<R, quinn_proto::WriteError>,
161    {
162        let mut state = self.conn.try_state()?;
163        if self.is_0rtt && !state.check_0rtt() {
164            return Poll::Ready(Err(WriteError::ZeroRttRejected));
165        }
166        match f(state.conn.send_stream(self.stream)) {
167            Ok(r) => {
168                state.wake();
169                Poll::Ready(Ok(r))
170            }
171            Err(e) => match e.try_into() {
172                Ok(e) => Poll::Ready(Err(e)),
173                Err(()) => {
174                    state.writable.insert(self.stream, cx.waker().clone());
175                    Poll::Pending
176                }
177            },
178        }
179    }
180
181    /// Write chunks to the stream.
182    ///
183    /// Yields the number of bytes and chunks written on success.
184    /// Congestion and flow control may cause this to be shorter than
185    /// `buf.len()`, indicating that only a prefix of `bufs` was written.
186    ///
187    /// This operation is cancel-safe.
188    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
189        poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write_chunks(bufs))).await
190    }
191
192    /// Convenience method to write an entire list of chunks to the stream.
193    ///
194    /// This operation is *not* cancel-safe.
195    pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
196        let mut chunks = 0;
197        poll_fn(|cx| {
198            loop {
199                if chunks == bufs.len() {
200                    return Poll::Ready(Ok(()));
201                }
202                let written = ready!(self.execute_poll_write(cx, |mut stream| {
203                    stream.write_chunks(&mut bufs[chunks..])
204                }))?;
205                chunks += written.chunks;
206            }
207        })
208        .await
209    }
210
211    /// Convert this stream into a [`futures_util`] compatible stream.
212    #[cfg(feature = "io-compat")]
213    pub fn into_compat(self) -> CompatSendStream {
214        CompatSendStream(self)
215    }
216}
217
218impl Drop for SendStream {
219    fn drop(&mut self) {
220        let mut state = self.conn.state();
221
222        // clean up any previously registered wakers
223        state.stopped.remove(&self.stream);
224        state.writable.remove(&self.stream);
225
226        if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
227            return;
228        }
229        match state.conn.send_stream(self.stream).finish() {
230            Ok(()) => state.wake(),
231            Err(FinishError::Stopped(reason)) => {
232                if state.conn.send_stream(self.stream).reset(reason).is_ok() {
233                    state.wake();
234                }
235            }
236            // Already finished or reset, which is fine.
237            Err(FinishError::ClosedStream) => {}
238        }
239    }
240}
241
242/// Errors that arise from writing to a stream
243#[derive(Debug, Error, Clone, PartialEq, Eq)]
244pub enum WriteError {
245    /// The peer is no longer accepting data on this stream
246    ///
247    /// Carries an application-defined error code.
248    #[error("sending stopped by peer: error {0}")]
249    Stopped(VarInt),
250    /// The connection was lost
251    #[error("connection lost")]
252    ConnectionLost(#[from] ConnectionError),
253    /// The stream has already been finished or reset
254    #[error("closed stream")]
255    ClosedStream,
256    /// This was a 0-RTT stream and the server rejected it
257    ///
258    /// Can only occur on clients for 0-RTT streams, which can be opened using
259    /// [`Connecting::into_0rtt()`].
260    ///
261    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
262    #[error("0-RTT rejected")]
263    ZeroRttRejected,
264    /// Error when the stream is not ready, because it is still sending
265    /// data from a previous call
266    #[cfg(feature = "h3")]
267    #[error("stream not ready")]
268    NotReady,
269}
270
271impl TryFrom<quinn_proto::WriteError> for WriteError {
272    type Error = ();
273
274    fn try_from(value: quinn_proto::WriteError) -> Result<Self, Self::Error> {
275        use quinn_proto::WriteError::*;
276        match value {
277            Stopped(e) => Ok(Self::Stopped(e)),
278            ClosedStream => Ok(Self::ClosedStream),
279            Blocked => Err(()),
280        }
281    }
282}
283
284impl From<StoppedError> for WriteError {
285    fn from(x: StoppedError) -> Self {
286        match x {
287            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
288            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
289        }
290    }
291}
292
293impl From<WriteError> for io::Error {
294    fn from(x: WriteError) -> Self {
295        use WriteError::*;
296        let kind = match x {
297            Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
298            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
299            #[cfg(feature = "h3")]
300            NotReady => io::ErrorKind::Other,
301        };
302        Self::new(kind, x)
303    }
304}
305
306impl AsyncWrite for SendStream {
307    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
308        let res =
309            poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf.as_init())))
310                .await
311                .map_err(Into::into);
312        BufResult(res, buf)
313    }
314
315    async fn flush(&mut self) -> io::Result<()> {
316        Ok(())
317    }
318
319    async fn shutdown(&mut self) -> io::Result<()> {
320        self.finish()?;
321        Ok(())
322    }
323}
324
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
    };

    use compio_buf::IntoInner;

    use super::*;

    /// A send stream usable through the [`futures_util`] I/O traits.
    pub struct CompatSendStream(pub(super) SendStream);

    impl CompatSendStream {
        /// Write bytes to the stream.
        ///
        /// Resolves to the number of bytes written. Congestion and flow
        /// control may make this smaller than `buf.len()`, meaning that only
        /// a prefix of `buf` was written.
        ///
        /// This operation is cancel-safe.
        pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
            let stream = &mut self.0;
            poll_fn(|cx| stream.execute_poll_write(cx, |mut proto| proto.write(buf))).await
        }

        /// Convenience method that writes the whole of `buf` to the stream.
        ///
        /// This operation is *not* cancel-safe.
        pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
            let stream = &mut self.0;
            // Number of leading bytes of `buf` already written.
            let mut sent = 0;
            poll_fn(|cx| {
                while sent != buf.len() {
                    let rest = &buf[sent..];
                    sent += ready!(stream.execute_poll_write(cx, |mut proto| proto.write(rest)))?;
                }
                Poll::Ready(Ok(()))
            })
            .await
        }
    }

    impl IntoInner for CompatSendStream {
        type Inner = SendStream;

        /// Unwrap the underlying [`SendStream`].
        fn into_inner(self) -> Self::Inner {
            let Self(stream) = self;
            stream
        }
    }

    impl Deref for CompatSendStream {
        type Target = SendStream;

        fn deref(&self) -> &Self::Target {
            let Self(stream) = self;
            stream
        }
    }

    impl DerefMut for CompatSendStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            let Self(stream) = self;
            stream
        }
    }

    impl futures_util::AsyncWrite for CompatSendStream {
        /// Make one write attempt with `buf`.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let stream = &mut self.get_mut().0;
            stream
                .execute_poll_write(cx, |mut proto| proto.write(buf))
                .map_err(io::Error::from)
        }

        /// Always ready; this implementation has nothing to flush.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Finish the stream, see [`SendStream::finish`].
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let stream = &mut self.get_mut().0;
            Poll::Ready(stream.finish().map_err(io::Error::from))
        }
    }
}
414
415#[cfg(feature = "io-compat")]
416pub use compat::CompatSendStream;
417
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use compio_buf::bytes::Buf;
    use h3::quic::{self, StreamErrorIncoming, WriteBuf};

    use super::*;

    /// Translates a [`WriteError`] into the error type expected by `h3`.
    ///
    /// Peer stops and connection losses map to their dedicated variants;
    /// every other error is boxed into `Unknown`.
    impl From<WriteError> for StreamErrorIncoming {
        fn from(e: WriteError) -> Self {
            use WriteError::*;
            match e {
                Stopped(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ConnectionLost(e) => Self::ConnectionErrorIncoming {
                    connection_error: e.into(),
                },
                e => Self::Unknown(Box::new(e)),
            }
        }
    }

    /// A wrapper around `SendStream` that implements `quic::SendStream` and
    /// `quic::SendStreamUnframed`.
    pub struct SendStream<B> {
        /// The underlying stream.
        inner: super::SendStream,
        /// Data accepted by `send_data` that has not been fully written yet.
        /// `None` means the stream is ready to accept new data.
        buf: Option<WriteBuf<B>>,
    }

    impl<B> SendStream<B> {
        /// Wrap the stream `stream` of connection `conn`, with no data
        /// pending.
        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                inner: super::SendStream::new(conn, stream, is_0rtt),
                buf: None,
            }
        }
    }

    impl<B> quic::SendStream<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Write out the data stored by `send_data`, if any.
        ///
        /// Becomes ready once the pending buffer has been written completely;
        /// the buffer is then cleared so that `send_data` can be called again.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            if let Some(data) = &mut self.buf {
                while data.has_remaining() {
                    let n = ready!(
                        self.inner
                            .execute_poll_write(cx, |mut stream| stream.write(data.chunk()))
                    )?;
                    data.advance(n);
                }
            }
            self.buf = None;
            Poll::Ready(Ok(()))
        }

        /// Store `data` to be written by subsequent `poll_ready` calls.
        ///
        /// Fails with [`WriteError::NotReady`] while previously stored data
        /// has not been completely written.
        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            if self.buf.is_some() {
                return Err(WriteError::NotReady.into());
            }
            self.buf = Some(data.into());
            Ok(())
        }

        /// Finish the underlying stream; any failure is reported as
        /// [`WriteError::ClosedStream`].
        fn poll_finish(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            Poll::Ready(
                self.inner
                    .finish()
                    .map_err(|_| WriteError::ClosedStream.into()),
            )
        }

        /// Reset the underlying stream, ignoring any failure.
        ///
        /// A `reset_code` that does not fit into a [`VarInt`] is replaced by
        /// [`VarInt::MAX`].
        fn reset(&mut self, reset_code: u64) {
            self.inner
                .reset(reset_code.try_into().unwrap_or(VarInt::MAX))
                .ok();
        }

        /// The identifier of the underlying stream.
        fn send_id(&self) -> quic::StreamId {
            u64::from(self.inner.stream)
                .try_into()
                .expect("a QUIC stream id is always a valid h3 stream id")
        }
    }

    impl<B> quic::SendStreamUnframed<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Make one write attempt with the current chunk of `buf`, advancing
        /// `buf` by the number of bytes written.
        ///
        /// Must only be called while the stream is ready, i.e. while no data
        /// from `send_data` is pending. Debug builds assert this.
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            // Pending `send_data` payload here signifies a bug in the caller:
            // writing now would put these bytes ahead of the buffered data.
            debug_assert!(
                self.buf.is_none(),
                "poll_send called while send stream is not ready"
            );

            let n = ready!(
                self.inner
                    .execute_poll_write(cx, |mut stream| stream.write(buf.chunk()))
            )?;
            buf.advance(n);
            Poll::Ready(Ok(n))
        }
    }
}
525}