compio_quic/
connection.rs

1use std::{
2    collections::VecDeque,
3    fmt::Debug,
4    net::{IpAddr, SocketAddr},
5    pin::{Pin, pin},
6    task::{Context, Poll, Waker},
7    time::{Duration, Instant},
8};
9
10use compio_buf::bytes::Bytes;
11use compio_log::Instrument;
12use compio_runtime::JoinHandle;
13use flume::{Receiver, Sender};
14use futures_util::{
15    Future, FutureExt, StreamExt,
16    future::{self, Fuse, FusedFuture, LocalBoxFuture},
17    select, stream,
18};
19#[cfg(rustls)]
20use quinn_proto::crypto::rustls::HandshakeData;
21use quinn_proto::{
22    ConnectionHandle, ConnectionStats, Dir, EndpointEvent, StreamEvent, StreamId, VarInt,
23    congestion::Controller,
24};
25use rustc_hash::FxHashMap as HashMap;
26use thiserror::Error;
27
28use crate::{
29    RecvStream, SendStream, Socket,
30    sync::{
31        mutex_blocking::{Mutex, MutexGuard},
32        shared::Shared,
33    },
34};
35
36#[derive(Debug)]
37pub(crate) enum ConnectionEvent {
38    Close(VarInt, Bytes),
39    Proto(quinn_proto::ConnectionEvent),
40}
41
42#[derive(Debug)]
43pub(crate) struct ConnectionState {
44    pub(crate) conn: quinn_proto::Connection,
45    pub(crate) error: Option<ConnectionError>,
46    connected: bool,
47    worker: Option<JoinHandle<()>>,
48    poller: Option<Waker>,
49    on_connected: Option<Waker>,
50    on_handshake_data: Option<Waker>,
51    datagram_received: VecDeque<Waker>,
52    datagrams_unblocked: VecDeque<Waker>,
53    stream_opened: [VecDeque<Waker>; 2],
54    stream_available: [VecDeque<Waker>; 2],
55    pub(crate) writable: HashMap<StreamId, Waker>,
56    pub(crate) readable: HashMap<StreamId, Waker>,
57    pub(crate) stopped: HashMap<StreamId, Waker>,
58}
59
60impl ConnectionState {
61    fn terminate(&mut self, reason: ConnectionError) {
62        self.error = Some(reason);
63        self.connected = false;
64
65        if let Some(waker) = self.on_handshake_data.take() {
66            waker.wake()
67        }
68        if let Some(waker) = self.on_connected.take() {
69            waker.wake()
70        }
71        self.datagram_received.drain(..).for_each(Waker::wake);
72        self.datagrams_unblocked.drain(..).for_each(Waker::wake);
73        for e in &mut self.stream_opened {
74            e.drain(..).for_each(Waker::wake);
75        }
76        for e in &mut self.stream_available {
77            e.drain(..).for_each(Waker::wake);
78        }
79        wake_all_streams(&mut self.writable);
80        wake_all_streams(&mut self.readable);
81        wake_all_streams(&mut self.stopped);
82    }
83
84    fn close(&mut self, error_code: VarInt, reason: Bytes) {
85        self.conn.close(Instant::now(), error_code, reason);
86        self.terminate(ConnectionError::LocallyClosed);
87        self.wake();
88    }
89
90    pub(crate) fn wake(&mut self) {
91        if let Some(waker) = self.poller.take() {
92            waker.wake()
93        }
94    }
95
96    #[cfg(rustls)]
97    fn handshake_data(&self) -> Option<Box<HandshakeData>> {
98        self.conn
99            .crypto_session()
100            .handshake_data()
101            .map(|data| data.downcast::<HandshakeData>().unwrap())
102    }
103
104    pub(crate) fn check_0rtt(&self) -> bool {
105        self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt()
106    }
107}
108
109fn wake_stream(stream: StreamId, wakers: &mut HashMap<StreamId, Waker>) {
110    if let Some(waker) = wakers.remove(&stream) {
111        waker.wake();
112    }
113}
114
115fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
116    wakers.drain().for_each(|(_, waker)| waker.wake())
117}
118
119#[derive(Debug)]
120pub(crate) struct ConnectionInner {
121    state: Mutex<ConnectionState>,
122    handle: ConnectionHandle,
123    socket: Socket,
124    events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
125    events_rx: Receiver<ConnectionEvent>,
126}
127
128fn implicit_close(this: &Shared<ConnectionInner>) {
129    if Shared::strong_count(this) == 2 {
130        this.state().close(0u32.into(), Bytes::new())
131    }
132}
133
134impl ConnectionInner {
135    fn new(
136        handle: ConnectionHandle,
137        conn: quinn_proto::Connection,
138        socket: Socket,
139        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
140        events_rx: Receiver<ConnectionEvent>,
141    ) -> Self {
142        Self {
143            state: Mutex::new(ConnectionState {
144                conn,
145                connected: false,
146                error: None,
147                worker: None,
148                poller: None,
149                on_connected: None,
150                on_handshake_data: None,
151                datagram_received: VecDeque::new(),
152                datagrams_unblocked: VecDeque::new(),
153                stream_opened: [VecDeque::new(), VecDeque::new()],
154                stream_available: [VecDeque::new(), VecDeque::new()],
155                writable: HashMap::default(),
156                readable: HashMap::default(),
157                stopped: HashMap::default(),
158            }),
159            handle,
160            socket,
161            events_tx,
162            events_rx,
163        }
164    }
165
166    #[inline]
167    pub(crate) fn state(&self) -> MutexGuard<'_, ConnectionState> {
168        self.state.lock()
169    }
170
171    #[inline]
172    pub(crate) fn try_state(&self) -> Result<MutexGuard<'_, ConnectionState>, ConnectionError> {
173        let state = self.state();
174        if let Some(error) = &state.error {
175            Err(error.clone())
176        } else {
177            Ok(state)
178        }
179    }
180
181    async fn run(&self) {
182        let mut poller = stream::poll_fn(|cx| {
183            let mut state = self.state();
184            let ready = state.poller.is_none();
185            match &state.poller {
186                Some(waker) if waker.will_wake(cx.waker()) => {}
187                _ => state.poller = Some(cx.waker().clone()),
188            };
189            if ready {
190                Poll::Ready(Some(()))
191            } else {
192                Poll::Pending
193            }
194        })
195        .fuse();
196
197        let mut timer = Timer::new();
198        let mut event_stream = self.events_rx.stream().ready_chunks(100);
199        let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
200        let mut transmit_fut = pin!(Fuse::terminated());
201
202        loop {
203            let mut state = select! {
204                _ = poller.select_next_some() => self.state(),
205                _ = timer => {
206                    timer.reset(None);
207                    let mut state = self.state();
208                    state.conn.handle_timeout(Instant::now());
209                    state
210                }
211                events = event_stream.select_next_some() => {
212                    let mut state = self.state();
213                    for event in events {
214                        match event {
215                            ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason),
216                            ConnectionEvent::Proto(event) => state.conn.handle_event(event),
217                        }
218                    }
219                    state
220                },
221                buf = transmit_fut => {
222                    // The following line is required to avoid "type annotations needed" error
223                    let mut buf: Vec<_> = buf;
224                    buf.clear();
225                    send_buf = Some(buf);
226                    self.state()
227                },
228            };
229
230            if let Some(mut buf) = send_buf.take() {
231                if let Some(transmit) = state.conn.poll_transmit(
232                    Instant::now(),
233                    self.socket.max_gso_segments(),
234                    &mut buf,
235                ) {
236                    transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse())
237                } else {
238                    send_buf = Some(buf);
239                }
240            }
241
242            timer.reset(state.conn.poll_timeout());
243
244            while let Some(event) = state.conn.poll_endpoint_events() {
245                let _ = self.events_tx.send((self.handle, event));
246            }
247
248            while let Some(event) = state.conn.poll() {
249                use quinn_proto::Event::*;
250                match event {
251                    HandshakeDataReady => {
252                        if let Some(waker) = state.on_handshake_data.take() {
253                            waker.wake()
254                        }
255                    }
256                    Connected => {
257                        state.connected = true;
258                        if let Some(waker) = state.on_connected.take() {
259                            waker.wake()
260                        }
261                        if state.conn.side().is_client() && !state.conn.accepted_0rtt() {
262                            // Wake up rejected 0-RTT streams so they can fail immediately with
263                            // `ZeroRttRejected` errors.
264                            wake_all_streams(&mut state.writable);
265                            wake_all_streams(&mut state.readable);
266                            wake_all_streams(&mut state.stopped);
267                        }
268                    }
269                    ConnectionLost { reason } => state.terminate(reason.into()),
270                    Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable),
271                    Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable),
272                    Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped),
273                    Stream(StreamEvent::Stopped { id, .. }) => {
274                        wake_stream(id, &mut state.stopped);
275                        wake_stream(id, &mut state.writable);
276                    }
277                    Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize]
278                        .drain(..)
279                        .for_each(Waker::wake),
280                    Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize]
281                        .drain(..)
282                        .for_each(Waker::wake),
283                    DatagramReceived => state.datagram_received.drain(..).for_each(Waker::wake),
284                    DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake),
285                }
286            }
287
288            if state.conn.is_drained() {
289                break;
290            }
291        }
292    }
293}
294
/// Expands to the accessors shared by `Connecting` and `Connection`. Both are
/// newtypes over `Shared<ConnectionInner>`, accessed as `self.0`.
macro_rules! conn_fn {
    () => {
        /// The local IP address which was used when the peer established
        /// the connection.
        ///
        /// This can be different from the address the endpoint is bound to, in case
        /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`.
        ///
        /// This will return `None` for clients, or when the platform does not
        /// expose this information.
        pub fn local_ip(&self) -> Option<IpAddr> {
            self.0.state().conn.local_ip()
        }

        /// The peer's UDP address, as currently tracked by the connection.
        ///
        /// This never panics: the handle keeps the connection state alive for
        /// as long as it exists.
        pub fn remote_address(&self) -> SocketAddr {
            self.0.state().conn.remote_address()
        }

        /// Current best estimate of this connection's latency (round-trip-time).
        pub fn rtt(&self) -> Duration {
            self.0.state().conn.rtt()
        }

        /// Connection statistics.
        pub fn stats(&self) -> ConnectionStats {
            self.0.state().conn.stats()
        }

        /// Current state of the congestion control algorithm. (For debugging
        /// purposes)
        pub fn congestion_state(&self) -> Box<dyn Controller> {
            self.0.state().conn.congestion_state().clone_box()
        }

        /// Cryptographic identity of the peer.
        ///
        /// Returns `None` in two cases: the crypto session reports no peer
        /// identity, or the identity it reports is not a rustls certificate
        /// chain (e.g. with a custom crypto provider). The second case used to
        /// panic.
        pub fn peer_identity(
            &self,
        ) -> Option<Box<Vec<rustls::pki_types::CertificateDer<'static>>>> {
            self.0
                .state()
                .conn
                .crypto_session()
                .peer_identity()
                .and_then(|identity| identity.downcast().ok())
        }

        /// Derive keying material from this connection's TLS session secrets.
        ///
        /// When both peers call this method with the same `label` and `context`
        /// arguments and `output` buffers of equal length, they will get the
        /// same sequence of bytes in `output`. These bytes are cryptographically
        /// strong and pseudorandom, and are suitable for use as keying material.
        ///
        /// This function fails if called with an empty `output` or called prior to
        /// the handshake completing.
        ///
        /// See [RFC5705](https://tools.ietf.org/html/rfc5705) for more information.
        pub fn export_keying_material(
            &self,
            output: &mut [u8],
            label: &[u8],
            context: &[u8],
        ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> {
            self.0
                .state()
                .conn
                .crypto_session()
                .export_keying_material(output, label, context)
        }
    };
}
369
370/// In-progress connection attempt future
371#[derive(Debug)]
372#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
373pub struct Connecting(Shared<ConnectionInner>);
374
375impl Connecting {
376    conn_fn!();
377
378    pub(crate) fn new(
379        handle: ConnectionHandle,
380        conn: quinn_proto::Connection,
381        socket: Socket,
382        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
383        events_rx: Receiver<ConnectionEvent>,
384    ) -> Self {
385        let inner = Shared::new(ConnectionInner::new(
386            handle, conn, socket, events_tx, events_rx,
387        ));
388        let worker = compio_runtime::spawn({
389            let inner = inner.clone();
390            async move { inner.run().await }.in_current_span()
391        });
392        inner.state().worker = Some(worker);
393        Self(inner)
394    }
395
396    /// Parameters negotiated during the handshake.
397    #[cfg(rustls)]
398    pub async fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
399        future::poll_fn(|cx| {
400            let mut state = self.0.try_state()?;
401            if let Some(data) = state.handshake_data() {
402                return Poll::Ready(Ok(data));
403            }
404
405            match &state.on_handshake_data {
406                Some(waker) if waker.will_wake(cx.waker()) => {}
407                _ => state.on_handshake_data = Some(cx.waker().clone()),
408            }
409
410            Poll::Pending
411        })
412        .await
413    }
414
415    /// Convert into a 0-RTT or 0.5-RTT connection at the cost of weakened
416    /// security.
417    ///
418    /// Returns `Ok` immediately if the local endpoint is able to attempt
419    /// sending 0/0.5-RTT data. If so, the returned [`Connection`] can be used
420    /// to send application data without waiting for the rest of the handshake
421    /// to complete, at the cost of weakened cryptographic security guarantees.
422    /// The [`Connection::accepted_0rtt`] method resolves when the handshake
423    /// does complete, at which point subsequently opened streams and written
424    /// data will have full cryptographic protection.
425    ///
426    /// ## Outgoing
427    ///
428    /// For outgoing connections, the initial attempt to convert to a
429    /// [`Connection`] which sends 0-RTT data will proceed if the
430    /// [`crypto::ClientConfig`][crate::crypto::ClientConfig] attempts to resume
431    /// a previous TLS session. However, **the remote endpoint may not actually
432    /// _accept_ the 0-RTT data**--yet still accept the connection attempt in
433    /// general. This possibility is conveyed through the
434    /// [`Connection::accepted_0rtt`] method--when the handshake completes, it
435    /// resolves to true if the 0-RTT data was accepted and false if it was
436    /// rejected. If it was rejected, the existence of streams opened and other
437    /// application data sent prior to the handshake completing will not be
438    /// conveyed to the remote application, and local operations on them will
439    /// return `ZeroRttRejected` errors.
440    ///
441    /// A server may reject 0-RTT data at its discretion, but accepting 0-RTT
442    /// data requires the relevant resumption state to be stored in the server,
443    /// which servers may limit or lose for various reasons including not
444    /// persisting resumption state across server restarts.
445    ///
446    /// ## Incoming
447    ///
448    /// For incoming connections, conversion to 0.5-RTT will always fully
449    /// succeed. `into_0rtt` will always return `Ok` and
450    /// [`Connection::accepted_0rtt`] will always resolve to true.
451    ///
452    /// ## Security
453    ///
454    /// On outgoing connections, this enables transmission of 0-RTT data, which
455    /// is vulnerable to replay attacks, and should therefore never invoke
456    /// non-idempotent operations.
457    ///
458    /// On incoming connections, this enables transmission of 0.5-RTT data,
459    /// which may be sent before TLS client authentication has occurred, and
460    /// should therefore not be used to send data for which client
461    /// authentication is being used.
462    pub fn into_0rtt(self) -> Result<Connection, Self> {
463        let is_ok = {
464            let state = self.0.state();
465            state.conn.has_0rtt() || state.conn.side().is_server()
466        };
467        if is_ok {
468            Ok(Connection(self.0.clone()))
469        } else {
470            Err(self)
471        }
472    }
473}
474
475impl Future for Connecting {
476    type Output = Result<Connection, ConnectionError>;
477
478    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
479        let mut state = self.0.try_state()?;
480
481        if state.connected {
482            return Poll::Ready(Ok(Connection(self.0.clone())));
483        }
484
485        match &state.on_connected {
486            Some(waker) if waker.will_wake(cx.waker()) => {}
487            _ => state.on_connected = Some(cx.waker().clone()),
488        }
489
490        Poll::Pending
491    }
492}
493
494impl Drop for Connecting {
495    fn drop(&mut self) {
496        implicit_close(&self.0)
497    }
498}
499
500/// A QUIC connection.
501#[derive(Debug, Clone)]
502pub struct Connection(Shared<ConnectionInner>);
503
504impl Connection {
505    conn_fn!();
506
507    /// Parameters negotiated during the handshake.
508    #[cfg(rustls)]
509    pub fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
510        Ok(self.0.try_state()?.handshake_data().unwrap())
511    }
512
513    /// Compute the maximum size of datagrams that may be passed to
514    /// [`send_datagram()`](Self::send_datagram).
515    ///
516    /// Returns `None` if datagrams are unsupported by the peer or disabled
517    /// locally.
518    ///
519    /// This may change over the lifetime of a connection according to variation
520    /// in the path MTU estimate. The peer can also enforce an arbitrarily small
521    /// fixed limit, but if the peer's limit is large this is guaranteed to be a
522    /// little over a kilobyte at minimum.
523    ///
524    /// Not necessarily the maximum size of received datagrams.
525    pub fn max_datagram_size(&self) -> Option<usize> {
526        self.0.state().conn.datagrams().max_size()
527    }
528
529    /// Bytes available in the outgoing datagram buffer.
530    ///
531    /// When greater than zero, calling [`send_datagram()`](Self::send_datagram)
532    /// with a datagram of at most this size is guaranteed not to cause older
533    /// datagrams to be dropped.
534    pub fn datagram_send_buffer_space(&self) -> usize {
535        self.0.state().conn.datagrams().send_buffer_space()
536    }
537
538    /// Modify the number of remotely initiated unidirectional streams that may
539    /// be concurrently open.
540    ///
541    /// No streams may be opened by the peer unless fewer than `count` are
542    /// already open. Large `count`s increase both minimum and worst-case
543    /// memory consumption.
544    pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
545        let mut state = self.0.state();
546        state.conn.set_max_concurrent_streams(Dir::Uni, count);
547        // May need to send MAX_STREAMS to make progress
548        state.wake();
549    }
550
551    /// See [`quinn_proto::TransportConfig::receive_window()`]
552    pub fn set_receive_window(&self, receive_window: VarInt) {
553        let mut state = self.0.state();
554        state.conn.set_receive_window(receive_window);
555        state.wake();
556    }
557
558    /// Modify the number of remotely initiated bidirectional streams that may
559    /// be concurrently open.
560    ///
561    /// No streams may be opened by the peer unless fewer than `count` are
562    /// already open. Large `count`s increase both minimum and worst-case
563    /// memory consumption.
564    pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
565        let mut state = self.0.state();
566        state.conn.set_max_concurrent_streams(Dir::Bi, count);
567        // May need to send MAX_STREAMS to make progress
568        state.wake();
569    }
570
571    /// Close the connection immediately.
572    ///
573    /// Pending operations will fail immediately with
574    /// [`ConnectionError::LocallyClosed`]. No more data is sent to the peer
575    /// and the peer may drop buffered data upon receiving
576    /// the CONNECTION_CLOSE frame.
577    ///
578    /// `error_code` and `reason` are not interpreted, and are provided directly
579    /// to the peer.
580    ///
581    /// `reason` will be truncated to fit in a single packet with overhead; to
582    /// improve odds that it is preserved in full, it should be kept under
583    /// 1KiB.
584    ///
585    /// # Gracefully closing a connection
586    ///
587    /// Only the peer last receiving application data can be certain that all
588    /// data is delivered. The only reliable action it can then take is to
589    /// close the connection, potentially with a custom error code. The
590    /// delivery of the final CONNECTION_CLOSE frame is very likely if both
591    /// endpoints stay online long enough, and [`Endpoint::shutdown()`] can
592    /// be used to provide sufficient time. Otherwise, the remote peer will
593    /// time out the connection, provided that the idle timeout is not
594    /// disabled.
595    ///
596    /// The sending side can not guarantee all stream data is delivered to the
597    /// remote application. It only knows the data is delivered to the QUIC
598    /// stack of the remote endpoint. Once the local side sends a
599    /// CONNECTION_CLOSE frame in response to calling [`close()`] the remote
600    /// endpoint may drop any data it received but is as yet undelivered to
601    /// the application, including data that was acknowledged as received to
602    /// the local endpoint.
603    ///
604    /// [`ConnectionError::LocallyClosed`]: ConnectionError::LocallyClosed
605    /// [`Endpoint::shutdown()`]: crate::Endpoint::shutdown
606    /// [`close()`]: Connection::close
607    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
608        self.0
609            .state()
610            .close(error_code, Bytes::copy_from_slice(reason));
611    }
612
613    /// Wait for the connection to be closed for any reason.
614    pub async fn closed(&self) -> ConnectionError {
615        let worker = self.0.state().worker.take();
616        if let Some(worker) = worker {
617            let _ = worker.await;
618        }
619
620        self.0.try_state().unwrap_err()
621    }
622
623    /// If the connection is closed, the reason why.
624    ///
625    /// Returns `None` if the connection is still open.
626    pub fn close_reason(&self) -> Option<ConnectionError> {
627        self.0.try_state().err()
628    }
629
630    fn poll_recv_datagram(&self, cx: &mut Context) -> Poll<Result<Bytes, ConnectionError>> {
631        let mut state = self.0.try_state()?;
632        if let Some(bytes) = state.conn.datagrams().recv() {
633            return Poll::Ready(Ok(bytes));
634        }
635        state.datagram_received.push_back(cx.waker().clone());
636        Poll::Pending
637    }
638
639    /// Receive an application datagram.
640    pub async fn recv_datagram(&self) -> Result<Bytes, ConnectionError> {
641        future::poll_fn(|cx| self.poll_recv_datagram(cx)).await
642    }
643
644    fn try_send_datagram(
645        &self,
646        cx: Option<&mut Context>,
647        data: Bytes,
648    ) -> Result<(), Result<SendDatagramError, Bytes>> {
649        use quinn_proto::SendDatagramError::*;
650        let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?;
651        state
652            .conn
653            .datagrams()
654            .send(data, cx.is_none())
655            .map_err(|err| match err {
656                UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer),
657                Disabled => Ok(SendDatagramError::Disabled),
658                TooLarge => Ok(SendDatagramError::TooLarge),
659                Blocked(data) => {
660                    state
661                        .datagrams_unblocked
662                        .push_back(cx.unwrap().waker().clone());
663                    Err(data)
664                }
665            })?;
666        state.wake();
667        Ok(())
668    }
669
670    /// Transmit `data` as an unreliable, unordered application datagram.
671    ///
672    /// Application datagrams are a low-level primitive. They may be lost or
673    /// delivered out of order, and `data` must both fit inside a single
674    /// QUIC packet and be smaller than the maximum dictated by the peer.
675    pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
676        self.try_send_datagram(None, data).map_err(Result::unwrap)
677    }
678
679    /// Transmit `data` as an unreliable, unordered application datagram.
680    ///
681    /// Unlike [`send_datagram()`], this method will wait for buffer space
682    /// during congestion conditions, which effectively prioritizes old
683    /// datagrams over new datagrams.
684    ///
685    /// See [`send_datagram()`] for details.
686    ///
687    /// [`send_datagram()`]: Connection::send_datagram
688    pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> {
689        let mut data = Some(data);
690        future::poll_fn(
691            |cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) {
692                Ok(()) => Poll::Ready(Ok(())),
693                Err(Ok(e)) => Poll::Ready(Err(e)),
694                Err(Err(b)) => {
695                    data.replace(b);
696                    Poll::Pending
697                }
698            },
699        )
700        .await
701    }
702
703    fn poll_open_stream(
704        &self,
705        cx: Option<&mut Context>,
706        dir: Dir,
707    ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
708        let mut state = self.0.try_state()?;
709        if let Some(stream) = state.conn.streams().open(dir) {
710            Poll::Ready(Ok((
711                stream,
712                state.conn.side().is_client() && state.conn.is_handshaking(),
713            )))
714        } else {
715            if let Some(cx) = cx {
716                state.stream_available[dir as usize].push_back(cx.waker().clone());
717            }
718            Poll::Pending
719        }
720    }
721
722    /// Initiate a new outgoing unidirectional stream.
723    ///
724    /// Streams are cheap and instantaneous to open. As a consequence, the peer
725    /// won't be notified that a stream has been opened until the stream is
726    /// actually used.
727    pub fn open_uni(&self) -> Result<SendStream, OpenStreamError> {
728        if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Uni)? {
729            Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
730        } else {
731            Err(OpenStreamError::StreamsExhausted)
732        }
733    }
734
735    /// Initiate a new outgoing unidirectional stream.
736    ///
737    /// Unlike [`open_uni()`], this method will wait for the connection to allow
738    /// a new stream to be opened.
739    ///
740    /// See [`open_uni()`] for details.
741    ///
742    /// [`open_uni()`]: crate::Connection::open_uni
743    pub async fn open_uni_wait(&self) -> Result<SendStream, ConnectionError> {
744        let (stream, is_0rtt) =
745            future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?;
746        Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
747    }
748
749    /// Initiate a new outgoing bidirectional stream.
750    ///
751    /// Streams are cheap and instantaneous to open. As a consequence, the peer
752    /// won't be notified that a stream has been opened until the stream is
753    /// actually used.
754    pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> {
755        if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Bi)? {
756            Ok((
757                SendStream::new(self.0.clone(), stream, is_0rtt),
758                RecvStream::new(self.0.clone(), stream, is_0rtt),
759            ))
760        } else {
761            Err(OpenStreamError::StreamsExhausted)
762        }
763    }
764
765    /// Initiate a new outgoing bidirectional stream.
766    ///
767    /// Unlike [`open_bi()`], this method will wait for the connection to allow
768    /// a new stream to be opened.
769    ///
770    /// See [`open_bi()`] for details.
771    ///
772    /// [`open_bi()`]: crate::Connection::open_bi
773    pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
774        let (stream, is_0rtt) =
775            future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?;
776        Ok((
777            SendStream::new(self.0.clone(), stream, is_0rtt),
778            RecvStream::new(self.0.clone(), stream, is_0rtt),
779        ))
780    }
781
782    fn poll_accept_stream(
783        &self,
784        cx: &mut Context,
785        dir: Dir,
786    ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
787        let mut state = self.0.try_state()?;
788        if let Some(stream) = state.conn.streams().accept(dir) {
789            state.wake();
790            Poll::Ready(Ok((stream, state.conn.is_handshaking())))
791        } else {
792            state.stream_opened[dir as usize].push_back(cx.waker().clone());
793            Poll::Pending
794        }
795    }
796
797    /// Accept the next incoming uni-directional stream
798    pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
799        let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?;
800        Ok(RecvStream::new(self.0.clone(), stream, is_0rtt))
801    }
802
803    /// Accept the next incoming bidirectional stream
804    ///
805    /// **Important Note**: The `Connection` that calls [`open_bi()`] must write
806    /// to its [`SendStream`] before the other `Connection` is able to
807    /// `accept_bi()`. Calling [`open_bi()`] then waiting on the [`RecvStream`]
808    /// without writing anything to [`SendStream`] will never succeed.
809    ///
810    /// [`accept_bi()`]: crate::Connection::accept_bi
811    /// [`open_bi()`]: crate::Connection::open_bi
812    /// [`SendStream`]: crate::SendStream
813    /// [`RecvStream`]: crate::RecvStream
814    pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
815        let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?;
816        Ok((
817            SendStream::new(self.0.clone(), stream, is_0rtt),
818            RecvStream::new(self.0.clone(), stream, is_0rtt),
819        ))
820    }
821
822    /// Wait for the connection to be fully established.
823    ///
824    /// For clients, the resulting value indicates if 0-RTT was accepted. For
825    /// servers, the resulting value is meaningless.
826    pub async fn accepted_0rtt(&self) -> Result<bool, ConnectionError> {
827        future::poll_fn(|cx| {
828            let mut state = self.0.try_state()?;
829
830            if state.connected {
831                return Poll::Ready(Ok(state.conn.accepted_0rtt()));
832            }
833
834            match &state.on_connected {
835                Some(waker) if waker.will_wake(cx.waker()) => {}
836                _ => state.on_connected = Some(cx.waker().clone()),
837            }
838
839            Poll::Pending
840        })
841        .await
842    }
843}
844
845impl PartialEq for Connection {
846    fn eq(&self, other: &Self) -> bool {
847        Shared::ptr_eq(&self.0, &other.0)
848    }
849}
850
851impl Eq for Connection {}
852
853impl Drop for Connection {
854    fn drop(&mut self) {
855        implicit_close(&self.0)
856    }
857}
858
859struct Timer {
860    deadline: Option<Instant>,
861    fut: Fuse<LocalBoxFuture<'static, ()>>,
862}
863
864impl Timer {
865    fn new() -> Self {
866        Self {
867            deadline: None,
868            fut: Fuse::terminated(),
869        }
870    }
871
872    fn reset(&mut self, deadline: Option<Instant>) {
873        if let Some(deadline) = deadline {
874            if self.deadline.is_none() || self.deadline != Some(deadline) {
875                self.fut = compio_runtime::time::sleep_until(deadline)
876                    .boxed_local()
877                    .fuse();
878            }
879        } else {
880            self.fut = Fuse::terminated();
881        }
882        self.deadline = deadline;
883    }
884}
885
886impl Future for Timer {
887    type Output = ();
888
889    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
890        self.fut.poll_unpin(cx)
891    }
892}
893
894impl FusedFuture for Timer {
895    fn is_terminated(&self) -> bool {
896        self.fut.is_terminated()
897    }
898}
899
/// Reasons why a connection might be lost
///
/// The variants mirror `quinn_proto::ConnectionError` one-to-one; the `From`
/// impl below performs that conversion.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
    /// The peer doesn't implement any supported version
    #[error("peer doesn't implement any supported version")]
    VersionMismatch,
    /// The peer violated the QUIC specification as understood by this
    /// implementation
    ///
    /// Displayed transparently as the wrapped transport error, and
    /// constructible from a `quinn_proto::TransportError` via `From`.
    #[error(transparent)]
    TransportError(#[from] quinn_proto::TransportError),
    /// The peer's QUIC stack aborted the connection automatically
    #[error("aborted by peer: {0}")]
    ConnectionClosed(quinn_proto::ConnectionClose),
    /// The peer closed the connection
    ///
    /// Carries the application-level close frame sent by the peer.
    #[error("closed by peer: {0}")]
    ApplicationClosed(quinn_proto::ApplicationClose),
    /// The peer is unable to continue processing this connection, usually due
    /// to having restarted
    #[error("reset by peer")]
    Reset,
    /// Communication with the peer has lapsed for longer than the negotiated
    /// idle timeout
    ///
    /// If neither side is sending keep-alives, a connection will time out after
    /// a long enough idle period even if the peer is still reachable. See
    /// also [`TransportConfig::max_idle_timeout()`](quinn_proto::TransportConfig::max_idle_timeout())
    /// and [`TransportConfig::keep_alive_interval()`](quinn_proto::TransportConfig::keep_alive_interval()).
    #[error("timed out")]
    TimedOut,
    /// The local application closed the connection
    #[error("closed")]
    LocallyClosed,
    /// The connection could not be created because not enough of the CID space
    /// is available
    ///
    /// Try using longer connection IDs.
    #[error("CIDs exhausted")]
    CidsExhausted,
}
939
940impl From<quinn_proto::ConnectionError> for ConnectionError {
941    fn from(value: quinn_proto::ConnectionError) -> Self {
942        use quinn_proto::ConnectionError::*;
943
944        match value {
945            VersionMismatch => ConnectionError::VersionMismatch,
946            TransportError(e) => ConnectionError::TransportError(e),
947            ConnectionClosed(e) => ConnectionError::ConnectionClosed(e),
948            ApplicationClosed(e) => ConnectionError::ApplicationClosed(e),
949            Reset => ConnectionError::Reset,
950            TimedOut => ConnectionError::TimedOut,
951            LocallyClosed => ConnectionError::LocallyClosed,
952            CidsExhausted => ConnectionError::CidsExhausted,
953        }
954    }
955}
956
/// Errors that can arise when sending a datagram
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum SendDatagramError {
    /// The peer does not support receiving datagram frames
    #[error("datagrams not supported by peer")]
    UnsupportedByPeer,
    /// Datagram support is disabled locally
    #[error("datagram support disabled")]
    Disabled,
    /// The datagram is larger than the connection can currently accommodate
    ///
    /// Indicates that the path MTU minus overhead or the limit advertised by
    /// the peer has been exceeded.
    #[error("datagram too large")]
    TooLarge,
    /// The connection was lost
    ///
    /// Built from a [`ConnectionError`] via `From` (so `?` converts it). The
    /// Display text does not include the reason; it is exposed as the error's
    /// `source()` instead.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
}
976
977/// Errors that can arise when trying to open a stream
978#[derive(Debug, Error, Clone, Eq, PartialEq)]
979pub enum OpenStreamError {
980    /// The connection was lost
981    #[error("connection lost")]
982    ConnectionLost(#[from] ConnectionError),
983    // The streams in the given direction are currently exhausted
984    #[error("streams exhausted")]
985    StreamsExhausted,
986}
987
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    //! Glue implementing the `h3` and `h3-datagram` QUIC abstraction traits
    //! for this crate's `Connection`, stream and error types.

    use std::sync::Arc;

    use compio_buf::bytes::Buf;
    use futures_util::ready;
    use h3::{
        error::Code,
        quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, WriteBuf},
    };
    use h3_datagram::{
        datagram::EncodedDatagram,
        quic_traits::{
            DatagramConnectionExt, RecvDatagram, SendDatagram, SendDatagramErrorIncoming,
        },
    };

    use super::*;
    use crate::send_stream::h3_impl::SendStream;

    impl From<ConnectionError> for ConnectionErrorIncoming {
        /// Translates a connection failure into the categories `h3` understands.
        ///
        /// Application closes keep their error code, and idle timeouts map to
        /// `Timeout`. Every other reason is wrapped unchanged as `Undefined`.
        /// Every variant is listed explicitly, so adding one to
        /// [`ConnectionError`] forces a decision here instead of silently
        /// falling into `Undefined`.
        fn from(e: ConnectionError) -> Self {
            use ConnectionError::*;
            match e {
                ApplicationClosed(close) => Self::ApplicationClose {
                    error_code: close.error_code.into_inner(),
                },
                TimedOut => Self::Timeout,
                e @ (VersionMismatch
                | TransportError(_)
                | ConnectionClosed(_)
                | Reset
                | LocallyClosed
                | CidsExhausted) => Self::Undefined(Arc::new(e)),
            }
        }
    }

    impl From<ConnectionError> for StreamErrorIncoming {
        /// Reports a connection failure observed while operating on a stream.
        fn from(e: ConnectionError) -> Self {
            Self::ConnectionErrorIncoming {
                connection_error: e.into(),
            }
        }
    }

    impl From<SendDatagramError> for SendDatagramErrorIncoming {
        /// Maps datagram send failures onto `h3-datagram`'s error categories.
        ///
        /// "Unsupported by peer" and "disabled locally" both mean datagrams
        /// cannot be used on this connection, so both become `NotAvailable`.
        fn from(e: SendDatagramError) -> Self {
            use SendDatagramError::*;
            match e {
                UnsupportedByPeer | Disabled => Self::NotAvailable,
                TooLarge => Self::TooLarge,
                ConnectionLost(e) => Self::ConnectionError(e.into()),
            }
        }
    }

    impl<B> SendDatagram<B> for Connection
    where
        B: Buf,
    {
        /// Flattens the encoded datagram into one contiguous `Bytes` buffer and
        /// sends it with the inherent [`Connection::send_datagram`].
        fn send_datagram<T: Into<EncodedDatagram<B>>>(
            &mut self,
            data: T,
        ) -> Result<(), SendDatagramErrorIncoming> {
            let mut buf: EncodedDatagram<B> = data.into();
            let buf = buf.copy_to_bytes(buf.remaining());
            Ok(Connection::send_datagram(self, buf)?)
        }
    }

    impl RecvDatagram for Connection {
        type Buffer = Bytes;

        /// Forwards to `poll_recv_datagram`, converting its error for `h3`.
        fn poll_incoming_datagram(
            &mut self,
            cx: &mut core::task::Context<'_>,
        ) -> Poll<Result<Self::Buffer, ConnectionErrorIncoming>> {
            Poll::Ready(Ok(ready!(self.poll_recv_datagram(cx))?))
        }
    }

    impl<B: Buf> DatagramConnectionExt<B> for Connection {
        type RecvDatagramHandler = Self;
        type SendDatagramHandler = Self;

        /// Returns a cloned handle to the same connection.
        fn send_datagram_handler(&self) -> Self::SendDatagramHandler {
            self.clone()
        }

        /// Returns a cloned handle to the same connection.
        fn recv_datagram_handler(&self) -> Self::RecvDatagramHandler {
            self.clone()
        }
    }

    /// Bidirectional stream.
    ///
    /// Pairs the `h3` send half with this crate's [`RecvStream`] for the same
    /// stream id, so both can be handed to `h3` as one object and later split.
    pub struct BidiStream<B> {
        send: SendStream<B>,
        recv: RecvStream,
    }

    impl<B> BidiStream<B> {
        /// Builds both halves of `stream` on the shared connection `conn`.
        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                send: SendStream::new(conn.clone(), stream, is_0rtt),
                recv: RecvStream::new(conn, stream, is_0rtt),
            }
        }
    }

    impl<B> quic::BidiStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        type RecvStream = RecvStream;
        type SendStream = SendStream<B>;

        /// Separates the stream into its independent send and receive halves.
        fn split(self) -> (Self::SendStream, Self::RecvStream) {
            (self.send, self.recv)
        }
    }

    /// Receive-side operations forwarded to the inner `RecvStream`.
    impl<B> quic::RecvStream for BidiStream<B>
    where
        B: Buf,
    {
        type Buf = Bytes;

        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            self.recv.poll_data(cx)
        }

        fn stop_sending(&mut self, error_code: u64) {
            self.recv.stop_sending(error_code)
        }

        fn recv_id(&self) -> quic::StreamId {
            self.recv.recv_id()
        }
    }

    /// Send-side operations forwarded to the inner `SendStream`.
    impl<B> quic::SendStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_ready(cx)
        }

        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            self.send.send_data(data)
        }

        fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_finish(cx)
        }

        fn reset(&mut self, reset_code: u64) {
            self.send.reset(reset_code)
        }

        fn send_id(&self) -> quic::StreamId {
            self.send.send_id()
        }
    }

    /// Unframed writes forwarded to the inner `SendStream`.
    impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            self.send.poll_send(cx, buf)
        }
    }

    /// Stream opener.
    ///
    /// Returned by `quic::Connection::opener`; it wraps a cloned [`Connection`]
    /// handle and forwards every operation to `Connection`'s own
    /// `quic::OpenStreams` implementation, so the two cannot drift apart.
    #[derive(Clone)]
    pub struct OpenStreams(Connection);

    impl<B> quic::OpenStreams<B> for OpenStreams
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            <Connection as quic::OpenStreams<B>>::poll_open_bidi(&mut self.0, cx)
        }

        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            <Connection as quic::OpenStreams<B>>::poll_open_send(&mut self.0, cx)
        }

        fn close(&mut self, code: Code, reason: &[u8]) {
            <Connection as quic::OpenStreams<B>>::close(&mut self.0, code, reason)
        }
    }

    impl<B> quic::OpenStreams<B> for Connection
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        /// Waits until a bidirectional stream can be opened, then opens it.
        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Waits until a unidirectional stream can be opened, then opens it.
        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?;
            Poll::Ready(Ok(SendStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Closes the connection with the given HTTP/3 error code and reason.
        ///
        /// # Panics
        ///
        /// Panics if `code.value()` cannot be converted into the close-code
        /// type (a QUIC `VarInt`, which cannot represent values above 2^62 - 1).
        fn close(&mut self, code: Code, reason: &[u8]) {
            Connection::close(self, code.value().try_into().expect("invalid code"), reason)
        }
    }

    impl<B> quic::Connection<B> for Connection
    where
        B: Buf,
    {
        type OpenStreams = OpenStreams;
        type RecvStream = RecvStream;

        /// Waits for the peer to open a unidirectional stream.
        fn poll_accept_recv(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Uni))?;
            Poll::Ready(Ok(RecvStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Waits for the peer to open a bidirectional stream.
        fn poll_accept_bidi(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Returns a stream opener backed by a cloned handle to this connection.
        fn opener(&self) -> Self::OpenStreams {
            OpenStreams(self.clone())
        }
    }
}