// compio_quic/connection.rs

1use std::{
2    collections::VecDeque,
3    fmt::Debug,
4    net::{IpAddr, SocketAddr},
5    pin::{Pin, pin},
6    task::{Context, Poll, Waker},
7    time::{Duration, Instant},
8};
9
10use compio_buf::bytes::Bytes;
11use compio_log::Instrument;
12use compio_runtime::JoinHandle;
13use flume::{Receiver, Sender};
14use futures_util::{
15    FutureExt, StreamExt,
16    future::{self, Fuse, FusedFuture, LocalBoxFuture},
17    select, stream,
18};
19#[cfg(rustls)]
20use quinn_proto::crypto::rustls::HandshakeData;
21use quinn_proto::{
22    ConnectionHandle, ConnectionStats, Dir, EndpointEvent, Side, StreamEvent, StreamId, VarInt,
23    congestion::Controller,
24};
25use rustc_hash::FxHashMap as HashMap;
26use thiserror::Error;
27
28use crate::{
29    RecvStream, SendStream, Socket,
30    sync::{
31        mutex_blocking::{Mutex, MutexGuard},
32        shared::Shared,
33    },
34};
35
/// Messages consumed by a connection's driver task (`ConnectionInner::run`)
/// from its `events_rx` channel.
#[derive(Debug)]
pub(crate) enum ConnectionEvent {
    /// Close the connection locally with the given error code and reason.
    Close(VarInt, Bytes),
    /// A protocol event, fed to `quinn_proto::Connection::handle_event`.
    Proto(quinn_proto::ConnectionEvent),
}
41
/// Mutable per-connection state, guarded by the mutex in `ConnectionInner`.
///
/// Besides the `quinn_proto` connection itself, it stores the wakers of the
/// tasks waiting on this connection so the driver task can wake them when the
/// matching protocol event arrives (or when the connection terminates).
#[derive(Debug)]
pub(crate) struct ConnectionState {
    /// The underlying protocol state machine.
    pub(crate) conn: quinn_proto::Connection,
    /// Terminal error; once set, `ConnectionInner::try_state` returns it.
    pub(crate) error: Option<ConnectionError>,
    /// Set on the `Connected` event, cleared again by `terminate`.
    connected: bool,
    /// Handle of the spawned driver task; taken by `Connection::closed` or
    /// detached when `run` finishes.
    worker: Option<JoinHandle<()>>,
    /// Waker of the driver task; `wake()` takes it to request another loop iteration.
    poller: Option<Waker>,
    /// Task waiting for `Connecting` to complete.
    on_connected: Option<Waker>,
    /// Task waiting in `Connecting::handshake_data`.
    on_handshake_data: Option<Waker>,
    /// Tasks waiting for an incoming application datagram.
    datagram_received: VecDeque<Waker>,
    /// Tasks waiting for space in the outgoing datagram buffer.
    datagrams_unblocked: VecDeque<Waker>,
    /// Tasks waiting to accept a peer-opened stream, indexed by `Dir as usize`.
    stream_opened: [VecDeque<Waker>; 2],
    /// Tasks waiting for permission to open a new stream, indexed by `Dir as usize`.
    stream_available: [VecDeque<Waker>; 2],
    /// Per-stream wakers, woken on `Writable` and `Stopped` stream events.
    pub(crate) writable: HashMap<StreamId, Waker>,
    /// Per-stream wakers, woken on `Readable` stream events.
    pub(crate) readable: HashMap<StreamId, Waker>,
    /// Per-stream wakers, woken on `Finished` and `Stopped` stream events.
    pub(crate) stopped: HashMap<StreamId, Waker>,
}
59
60impl ConnectionState {
61    fn terminate(&mut self, reason: ConnectionError) {
62        self.error = Some(reason);
63        self.connected = false;
64
65        if let Some(waker) = self.on_handshake_data.take() {
66            waker.wake()
67        }
68        if let Some(waker) = self.on_connected.take() {
69            waker.wake()
70        }
71        self.datagram_received.drain(..).for_each(Waker::wake);
72        self.datagrams_unblocked.drain(..).for_each(Waker::wake);
73        for e in &mut self.stream_opened {
74            e.drain(..).for_each(Waker::wake);
75        }
76        for e in &mut self.stream_available {
77            e.drain(..).for_each(Waker::wake);
78        }
79        wake_all_streams(&mut self.writable);
80        wake_all_streams(&mut self.readable);
81        wake_all_streams(&mut self.stopped);
82    }
83
84    fn close(&mut self, error_code: VarInt, reason: Bytes) {
85        self.conn.close(Instant::now(), error_code, reason);
86        self.terminate(ConnectionError::LocallyClosed);
87        self.wake();
88    }
89
90    pub(crate) fn wake(&mut self) {
91        if let Some(waker) = self.poller.take() {
92            waker.wake()
93        }
94    }
95
96    #[cfg(rustls)]
97    fn handshake_data(&self) -> Option<Box<HandshakeData>> {
98        self.conn
99            .crypto_session()
100            .handshake_data()
101            .map(|data| data.downcast::<HandshakeData>().unwrap())
102    }
103
104    pub(crate) fn check_0rtt(&self) -> bool {
105        self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt()
106    }
107}
108
109fn wake_stream(stream: StreamId, wakers: &mut HashMap<StreamId, Waker>) {
110    if let Some(waker) = wakers.remove(&stream) {
111        waker.wake();
112    }
113}
114
115fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
116    wakers.drain().for_each(|(_, waker)| waker.wake())
117}
118
/// Shared core of a connection. It is referenced by `Connecting`,
/// `Connection`, the streams opened on it, and the spawned driver task.
#[derive(Debug)]
pub(crate) struct ConnectionInner {
    /// Protocol state and waiter registrations.
    state: Mutex<ConnectionState>,
    /// This connection's handle; attached to every endpoint event sent on `events_tx`.
    handle: ConnectionHandle,
    /// Socket used to transmit outgoing packets.
    socket: Socket,
    /// Channel for forwarding endpoint events produced by this connection.
    events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
    /// Incoming `ConnectionEvent`s, consumed by `run`.
    events_rx: Receiver<ConnectionEvent>,
}
127
128fn implicit_close(this: &Shared<ConnectionInner>) {
129    if Shared::strong_count(this) == 2 {
130        this.state().close(0u32.into(), Bytes::new())
131    }
132}
133
impl ConnectionInner {
    /// Build the shared core around a freshly created `quinn_proto` connection.
    ///
    /// The connection starts out not connected and without an error. No driver
    /// task is attached yet (`Connecting::new` fills in `worker`), and every
    /// waiter list is empty.
    fn new(
        handle: ConnectionHandle,
        conn: quinn_proto::Connection,
        socket: Socket,
        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
        events_rx: Receiver<ConnectionEvent>,
    ) -> Self {
        Self {
            state: Mutex::new(ConnectionState {
                conn,
                connected: false,
                error: None,
                worker: None,
                poller: None,
                on_connected: None,
                on_handshake_data: None,
                datagram_received: VecDeque::new(),
                datagrams_unblocked: VecDeque::new(),
                stream_opened: [VecDeque::new(), VecDeque::new()],
                stream_available: [VecDeque::new(), VecDeque::new()],
                writable: HashMap::default(),
                readable: HashMap::default(),
                stopped: HashMap::default(),
            }),
            handle,
            socket,
            events_tx,
            events_rx,
        }
    }

    /// Lock the connection state unconditionally, even if the connection has
    /// already been terminated.
    #[inline]
    pub(crate) fn state(&self) -> MutexGuard<'_, ConnectionState> {
        self.state.lock()
    }

    /// Lock the connection state, or return a clone of the recorded error if
    /// the connection has been terminated.
    #[inline]
    pub(crate) fn try_state(&self) -> Result<MutexGuard<'_, ConnectionState>, ConnectionError> {
        let state = self.state();
        if let Some(error) = &state.error {
            Err(error.clone())
        } else {
            Ok(state)
        }
    }

    /// Drive the connection until `quinn_proto` reports it as drained.
    ///
    /// Each loop iteration waits for one of four things:
    ///
    /// - a wake-up requested through `ConnectionState::wake`,
    /// - the protocol timer firing,
    /// - a batch of incoming `ConnectionEvent`s,
    /// - the in-flight transmit finishing.
    ///
    /// Then, with the state locked, it starts at most one new transmit, re-arms
    /// the timer, forwards endpoint events, and wakes the tasks interested in
    /// each protocol event.
    async fn run(&self) {
        // Yields whenever `state.poller` is empty, i.e. on the first poll and
        // after `ConnectionState::wake` took the stored waker. It always leaves
        // this task's waker in `poller`, so the next `wake()` reaches it.
        let mut poller = stream::poll_fn(|cx| {
            let mut state = self.state();
            let ready = state.poller.is_none();
            match &state.poller {
                Some(waker) if waker.will_wake(cx.waker()) => {}
                _ => state.poller = Some(cx.waker().clone()),
            };
            if ready {
                Poll::Ready(Some(()))
            } else {
                Poll::Pending
            }
        })
        .fuse();

        let mut timer = Timer::new();
        // Handle up to 100 already-queued events per wake-up.
        let mut event_stream = self.events_rx.stream().ready_chunks(100);
        // The transmit buffer: `Some` while idle, `None` while `transmit_fut` owns it.
        let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
        // At most one send is in flight at any time.
        let mut transmit_fut = pin!(Fuse::terminated());

        loop {
            let mut state = select! {
                _ = poller.select_next_some() => self.state(),
                _ = timer => {
                    timer.reset(None);
                    let mut state = self.state();
                    state.conn.handle_timeout(Instant::now());
                    state
                }
                events = event_stream.select_next_some() => {
                    let mut state = self.state();
                    for event in events {
                        match event {
                            ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason),
                            ConnectionEvent::Proto(event) => state.conn.handle_event(event),
                        }
                    }
                    state
                },
                buf = transmit_fut => {
                    // The send finished: reclaim its buffer for the next transmit.
                    // The following line is required to avoid "type annotations needed" error
                    let mut buf: Vec<_> = buf;
                    buf.clear();
                    send_buf = Some(buf);
                    self.state()
                },
            };

            // Start a transmit only when none is in flight. When `transmit_fut`
            // completes, that triggers another iteration, which sends the next one.
            if let Some(mut buf) = send_buf.take() {
                if let Some(transmit) = state.conn.poll_transmit(
                    Instant::now(),
                    self.socket.max_gso_segments(),
                    &mut buf,
                ) {
                    transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse())
                } else {
                    send_buf = Some(buf);
                }
            }

            // Re-arm the timer with the protocol's next deadline, if any.
            timer.reset(state.conn.poll_timeout());

            // Forward endpoint events. Send failures are ignored (presumably
            // because the endpoint side is already gone).
            while let Some(event) = state.conn.poll_endpoint_events() {
                let _ = self.events_tx.send((self.handle, event));
            }

            // Dispatch protocol events to the tasks waiting on them.
            while let Some(event) = state.conn.poll() {
                use quinn_proto::Event::*;
                match event {
                    HandshakeDataReady => {
                        if let Some(waker) = state.on_handshake_data.take() {
                            waker.wake()
                        }
                    }
                    Connected => {
                        state.connected = true;
                        if let Some(waker) = state.on_connected.take() {
                            waker.wake()
                        }
                        if state.conn.side().is_client() && !state.conn.accepted_0rtt() {
                            // Wake up rejected 0-RTT streams so they can fail immediately with
                            // `ZeroRttRejected` errors.
                            wake_all_streams(&mut state.writable);
                            wake_all_streams(&mut state.readable);
                            wake_all_streams(&mut state.stopped);
                        }
                    }
                    ConnectionLost { reason } => state.terminate(reason.into()),
                    Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable),
                    Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable),
                    Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped),
                    Stream(StreamEvent::Stopped { id, .. }) => {
                        wake_stream(id, &mut state.stopped);
                        wake_stream(id, &mut state.writable);
                    }
                    Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    DatagramReceived => state.datagram_received.drain(..).for_each(Waker::wake),
                    DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake),
                }
            }

            if state.conn.is_drained() {
                break;
            }
        }

        // Break the reference cycle: `state.worker` holds the handle of this
        // very task, and the task holds a reference to `self`.
        if let Some(worker) = self.state().worker.take() {
            worker.detach();
        }
    }
}
299
/// Accessors shared by `Connecting` and `Connection`. Both are newtypes over
/// `Shared<ConnectionInner>`, so `self.0` is the shared connection core.
macro_rules! conn_fn {
    () => {
        /// The side of the connection (client or server).
        pub fn side(&self) -> Side {
            self.0.state().conn.side()
        }

        /// The local IP address which was used when the peer established
        /// the connection.
        ///
        /// This can be different from the address the endpoint is bound to, in case
        /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`.
        ///
        /// This will return `None` for clients, or when the platform does not
        /// expose this information.
        pub fn local_ip(&self) -> Option<IpAddr> {
            self.0.state().conn.local_ip()
        }

        /// The peer's UDP address.
        ///
        /// When a `Connecting` resolves into a `Connection`, the protocol state
        /// is shared rather than moved out. This therefore stays valid after
        /// the handshake has completed.
        pub fn remote_address(&self) -> SocketAddr {
            self.0.state().conn.remote_address()
        }

        /// Current best estimate of this connection's latency (round-trip-time).
        pub fn rtt(&self) -> Duration {
            self.0.state().conn.rtt()
        }

        /// Connection statistics.
        pub fn stats(&self) -> ConnectionStats {
            self.0.state().conn.stats()
        }

        /// Current state of the congestion control algorithm (for debugging
        /// purposes).
        ///
        /// Returns a boxed copy (`clone_box`) of the controller.
        pub fn congestion_state(&self) -> Box<dyn Controller> {
            self.0.state().conn.congestion_state().clone_box()
        }

        /// Cryptographic identity of the peer, if one is available.
        ///
        /// # Panics
        ///
        /// Panics if the crypto session reports an identity that is not a
        /// `Vec<CertificateDer<'static>>` (presumably only possible with a
        /// non-rustls session).
        pub fn peer_identity(
            &self,
        ) -> Option<Box<Vec<rustls::pki_types::CertificateDer<'static>>>> {
            self.0
                .state()
                .conn
                .crypto_session()
                .peer_identity()
                .map(|v| v.downcast().unwrap())
        }

        /// A stable identifier for this connection.
        ///
        /// Peer addresses and connection IDs can change, but this value will remain
        /// fixed for the lifetime of the connection. A `Connecting` and the
        /// `Connection` it resolves to share the same identifier.
        pub fn stable_id(&self) -> usize {
            Shared::as_ptr(&self.0) as usize
        }

        /// Derive keying material from this connection's TLS session secrets.
        ///
        /// When both peers call this method with the same `label` and `context`
        /// arguments and `output` buffers of equal length, they will get the
        /// same sequence of bytes in `output`. These bytes are cryptographically
        /// strong and pseudorandom, and are suitable for use as keying material.
        ///
        /// This function fails if called with an empty `output` or called prior to
        /// the handshake completing.
        ///
        /// See [RFC5705](https://tools.ietf.org/html/rfc5705) for more information.
        pub fn export_keying_material(
            &self,
            output: &mut [u8],
            label: &[u8],
            context: &[u8],
        ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> {
            self.0
                .state()
                .conn
                .crypto_session()
                .export_keying_material(output, label, context)
        }
    };
}
387
/// In-progress connection attempt future.
///
/// Resolves to a [`Connection`] once the handshake completes, or to the
/// [`ConnectionError`] that terminated the attempt. Dropping it closes the
/// connection if exactly one other reference to the shared state remains
/// (normally the driver task's).
#[derive(Debug)]
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
pub struct Connecting(Shared<ConnectionInner>);
392
393impl Connecting {
394    conn_fn!();
395
396    pub(crate) fn new(
397        handle: ConnectionHandle,
398        conn: quinn_proto::Connection,
399        socket: Socket,
400        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
401        events_rx: Receiver<ConnectionEvent>,
402    ) -> Self {
403        let inner = Shared::new(ConnectionInner::new(
404            handle, conn, socket, events_tx, events_rx,
405        ));
406        let worker = compio_runtime::spawn({
407            let inner = inner.clone();
408            async move { inner.run().await }.in_current_span()
409        });
410        inner.state().worker = Some(worker);
411        Self(inner)
412    }
413
414    /// Parameters negotiated during the handshake.
415    #[cfg(rustls)]
416    pub async fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
417        future::poll_fn(|cx| {
418            let mut state = self.0.try_state()?;
419            if let Some(data) = state.handshake_data() {
420                return Poll::Ready(Ok(data));
421            }
422
423            match &state.on_handshake_data {
424                Some(waker) if waker.will_wake(cx.waker()) => {}
425                _ => state.on_handshake_data = Some(cx.waker().clone()),
426            }
427
428            Poll::Pending
429        })
430        .await
431    }
432
433    /// Convert into a 0-RTT or 0.5-RTT connection at the cost of weakened
434    /// security.
435    ///
436    /// Returns `Ok` immediately if the local endpoint is able to attempt
437    /// sending 0/0.5-RTT data. If so, the returned [`Connection`] can be used
438    /// to send application data without waiting for the rest of the handshake
439    /// to complete, at the cost of weakened cryptographic security guarantees.
440    /// The [`Connection::accepted_0rtt`] method resolves when the handshake
441    /// does complete, at which point subsequently opened streams and written
442    /// data will have full cryptographic protection.
443    ///
444    /// ## Outgoing
445    ///
446    /// For outgoing connections, the initial attempt to convert to a
447    /// [`Connection`] which sends 0-RTT data will proceed if the
448    /// [`crypto::ClientConfig`][crate::crypto::ClientConfig] attempts to resume
449    /// a previous TLS session. However, **the remote endpoint may not actually
450    /// _accept_ the 0-RTT data**--yet still accept the connection attempt in
451    /// general. This possibility is conveyed through the
452    /// [`Connection::accepted_0rtt`] method--when the handshake completes, it
453    /// resolves to true if the 0-RTT data was accepted and false if it was
454    /// rejected. If it was rejected, the existence of streams opened and other
455    /// application data sent prior to the handshake completing will not be
456    /// conveyed to the remote application, and local operations on them will
457    /// return `ZeroRttRejected` errors.
458    ///
459    /// A server may reject 0-RTT data at its discretion, but accepting 0-RTT
460    /// data requires the relevant resumption state to be stored in the server,
461    /// which servers may limit or lose for various reasons including not
462    /// persisting resumption state across server restarts.
463    ///
464    /// ## Incoming
465    ///
466    /// For incoming connections, conversion to 0.5-RTT will always fully
467    /// succeed. `into_0rtt` will always return `Ok` and
468    /// [`Connection::accepted_0rtt`] will always resolve to true.
469    ///
470    /// ## Security
471    ///
472    /// On outgoing connections, this enables transmission of 0-RTT data, which
473    /// is vulnerable to replay attacks, and should therefore never invoke
474    /// non-idempotent operations.
475    ///
476    /// On incoming connections, this enables transmission of 0.5-RTT data,
477    /// which may be sent before TLS client authentication has occurred, and
478    /// should therefore not be used to send data for which client
479    /// authentication is being used.
480    pub fn into_0rtt(self) -> Result<Connection, Self> {
481        let is_ok = {
482            let state = self.0.state();
483            state.conn.has_0rtt() || state.conn.side().is_server()
484        };
485        if is_ok {
486            Ok(Connection(self.0.clone()))
487        } else {
488            Err(self)
489        }
490    }
491}
492
493impl Future for Connecting {
494    type Output = Result<Connection, ConnectionError>;
495
496    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
497        let mut state = self.0.try_state()?;
498
499        if state.connected {
500            return Poll::Ready(Ok(Connection(self.0.clone())));
501        }
502
503        match &state.on_connected {
504            Some(waker) if waker.will_wake(cx.waker()) => {}
505            _ => state.on_connected = Some(cx.waker().clone()),
506        }
507
508        Poll::Pending
509    }
510}
511
512impl Drop for Connecting {
513    fn drop(&mut self) {
514        implicit_close(&self.0)
515    }
516}
517
/// A QUIC connection.
///
/// Cloning it yields another handle to the same shared connection state.
#[derive(Debug, Clone)]
pub struct Connection(Shared<ConnectionInner>);
521
522impl Connection {
523    conn_fn!();
524
525    /// Update traffic keys spontaneously
526    ///
527    /// This primarily exists for testing purposes.
528    pub fn force_key_update(&self) {
529        self.0.state().conn.force_key_update()
530    }
531
532    /// Parameters negotiated during the handshake.
533    #[cfg(rustls)]
534    pub fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
535        Ok(self.0.try_state()?.handshake_data().unwrap())
536    }
537
538    /// Compute the maximum size of datagrams that may be passed to
539    /// [`send_datagram()`](Self::send_datagram).
540    ///
541    /// Returns `None` if datagrams are unsupported by the peer or disabled
542    /// locally.
543    ///
544    /// This may change over the lifetime of a connection according to variation
545    /// in the path MTU estimate. The peer can also enforce an arbitrarily small
546    /// fixed limit, but if the peer's limit is large this is guaranteed to be a
547    /// little over a kilobyte at minimum.
548    ///
549    /// Not necessarily the maximum size of received datagrams.
550    pub fn max_datagram_size(&self) -> Option<usize> {
551        self.0.state().conn.datagrams().max_size()
552    }
553
554    /// Bytes available in the outgoing datagram buffer.
555    ///
556    /// When greater than zero, calling [`send_datagram()`](Self::send_datagram)
557    /// with a datagram of at most this size is guaranteed not to cause older
558    /// datagrams to be dropped.
559    pub fn datagram_send_buffer_space(&self) -> usize {
560        self.0.state().conn.datagrams().send_buffer_space()
561    }
562
563    /// Modify the number of remotely initiated unidirectional streams that may
564    /// be concurrently open.
565    ///
566    /// No streams may be opened by the peer unless fewer than `count` are
567    /// already open. Large `count`s increase both minimum and worst-case
568    /// memory consumption.
569    pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
570        let mut state = self.0.state();
571        state.conn.set_max_concurrent_streams(Dir::Uni, count);
572        // May need to send MAX_STREAMS to make progress
573        state.wake();
574    }
575
576    /// See [`quinn_proto::TransportConfig::send_window()`]
577    pub fn set_send_window(&self, send_window: u64) {
578        let mut state = self.0.state();
579        state.conn.set_send_window(send_window);
580        state.wake();
581    }
582
583    /// See [`quinn_proto::TransportConfig::receive_window()`]
584    pub fn set_receive_window(&self, receive_window: VarInt) {
585        let mut state = self.0.state();
586        state.conn.set_receive_window(receive_window);
587        state.wake();
588    }
589
590    /// Modify the number of remotely initiated bidirectional streams that may
591    /// be concurrently open.
592    ///
593    /// No streams may be opened by the peer unless fewer than `count` are
594    /// already open. Large `count`s increase both minimum and worst-case
595    /// memory consumption.
596    pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
597        let mut state = self.0.state();
598        state.conn.set_max_concurrent_streams(Dir::Bi, count);
599        // May need to send MAX_STREAMS to make progress
600        state.wake();
601    }
602
603    /// Close the connection immediately.
604    ///
605    /// Pending operations will fail immediately with
606    /// [`ConnectionError::LocallyClosed`]. No more data is sent to the peer
607    /// and the peer may drop buffered data upon receiving
608    /// the CONNECTION_CLOSE frame.
609    ///
610    /// `error_code` and `reason` are not interpreted, and are provided directly
611    /// to the peer.
612    ///
613    /// `reason` will be truncated to fit in a single packet with overhead; to
614    /// improve odds that it is preserved in full, it should be kept under
615    /// 1KiB.
616    ///
617    /// # Gracefully closing a connection
618    ///
619    /// Only the peer last receiving application data can be certain that all
620    /// data is delivered. The only reliable action it can then take is to
621    /// close the connection, potentially with a custom error code. The
622    /// delivery of the final CONNECTION_CLOSE frame is very likely if both
623    /// endpoints stay online long enough, and [`Endpoint::shutdown()`] can
624    /// be used to provide sufficient time. Otherwise, the remote peer will
625    /// time out the connection, provided that the idle timeout is not
626    /// disabled.
627    ///
628    /// The sending side can not guarantee all stream data is delivered to the
629    /// remote application. It only knows the data is delivered to the QUIC
630    /// stack of the remote endpoint. Once the local side sends a
631    /// CONNECTION_CLOSE frame in response to calling [`close()`] the remote
632    /// endpoint may drop any data it received but is as yet undelivered to
633    /// the application, including data that was acknowledged as received to
634    /// the local endpoint.
635    ///
636    /// [`ConnectionError::LocallyClosed`]: ConnectionError::LocallyClosed
637    /// [`Endpoint::shutdown()`]: crate::Endpoint::shutdown
638    /// [`close()`]: Connection::close
639    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
640        self.0
641            .state()
642            .close(error_code, Bytes::copy_from_slice(reason));
643    }
644
645    /// Wait for the connection to be closed for any reason.
646    pub async fn closed(&self) -> ConnectionError {
647        let worker = self.0.state().worker.take();
648        if let Some(worker) = worker {
649            let _ = worker.await;
650        }
651
652        self.0.try_state().unwrap_err()
653    }
654
655    /// If the connection is closed, the reason why.
656    ///
657    /// Returns `None` if the connection is still open.
658    pub fn close_reason(&self) -> Option<ConnectionError> {
659        self.0.try_state().err()
660    }
661
662    fn poll_recv_datagram(&self, cx: &mut Context) -> Poll<Result<Bytes, ConnectionError>> {
663        let mut state = self.0.try_state()?;
664        if let Some(bytes) = state.conn.datagrams().recv() {
665            return Poll::Ready(Ok(bytes));
666        }
667        state.datagram_received.push_back(cx.waker().clone());
668        Poll::Pending
669    }
670
671    /// Try to receive an application datagram. Returns None if no datagram is
672    /// available.
673    pub fn try_recv_datagram(&self) -> Result<Option<Bytes>, ConnectionError> {
674        let mut state = self.0.try_state()?;
675        Ok(state.conn.datagrams().recv())
676    }
677
678    /// Receive an application datagram.
679    pub async fn recv_datagram(&self) -> Result<Bytes, ConnectionError> {
680        future::poll_fn(|cx| self.poll_recv_datagram(cx)).await
681    }
682
683    fn try_send_datagram(
684        &self,
685        cx: Option<&mut Context>,
686        data: Bytes,
687    ) -> Result<(), Result<SendDatagramError, Bytes>> {
688        use quinn_proto::SendDatagramError::*;
689        let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?;
690        state
691            .conn
692            .datagrams()
693            .send(data, cx.is_none())
694            .map_err(|err| match err {
695                UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer),
696                Disabled => Ok(SendDatagramError::Disabled),
697                TooLarge => Ok(SendDatagramError::TooLarge),
698                Blocked(data) => {
699                    state
700                        .datagrams_unblocked
701                        .push_back(cx.unwrap().waker().clone());
702                    Err(data)
703                }
704            })?;
705        state.wake();
706        Ok(())
707    }
708
709    /// Transmit `data` as an unreliable, unordered application datagram.
710    ///
711    /// Application datagrams are a low-level primitive. They may be lost or
712    /// delivered out of order, and `data` must both fit inside a single
713    /// QUIC packet and be smaller than the maximum dictated by the peer.
714    pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
715        self.try_send_datagram(None, data).map_err(Result::unwrap)
716    }
717
718    /// Transmit `data` as an unreliable, unordered application datagram.
719    ///
720    /// Unlike [`send_datagram()`], this method will wait for buffer space
721    /// during congestion conditions, which effectively prioritizes old
722    /// datagrams over new datagrams.
723    ///
724    /// See [`send_datagram()`] for details.
725    ///
726    /// [`send_datagram()`]: Connection::send_datagram
727    pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> {
728        let mut data = Some(data);
729        future::poll_fn(
730            |cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) {
731                Ok(()) => Poll::Ready(Ok(())),
732                Err(Ok(e)) => Poll::Ready(Err(e)),
733                Err(Err(b)) => {
734                    data.replace(b);
735                    Poll::Pending
736                }
737            },
738        )
739        .await
740    }
741
742    fn poll_open_stream(
743        &self,
744        cx: Option<&mut Context>,
745        dir: Dir,
746    ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
747        let mut state = self.0.try_state()?;
748        if let Some(stream) = state.conn.streams().open(dir) {
749            Poll::Ready(Ok((
750                stream,
751                state.conn.side().is_client() && state.conn.is_handshaking(),
752            )))
753        } else {
754            if let Some(cx) = cx {
755                state.stream_available[dir as usize].push_back(cx.waker().clone());
756            }
757            Poll::Pending
758        }
759    }
760
761    /// Initiate a new outgoing unidirectional stream.
762    ///
763    /// Streams are cheap and instantaneous to open. As a consequence, the peer
764    /// won't be notified that a stream has been opened until the stream is
765    /// actually used.
766    pub fn open_uni(&self) -> Result<SendStream, OpenStreamError> {
767        if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Uni)? {
768            Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
769        } else {
770            Err(OpenStreamError::StreamsExhausted)
771        }
772    }
773
774    /// Initiate a new outgoing unidirectional stream.
775    ///
776    /// Unlike [`open_uni()`], this method will wait for the connection to allow
777    /// a new stream to be opened.
778    ///
779    /// See [`open_uni()`] for details.
780    ///
781    /// [`open_uni()`]: crate::Connection::open_uni
782    pub async fn open_uni_wait(&self) -> Result<SendStream, ConnectionError> {
783        let (stream, is_0rtt) =
784            future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?;
785        Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
786    }
787
788    /// Initiate a new outgoing bidirectional stream.
789    ///
790    /// Streams are cheap and instantaneous to open. As a consequence, the peer
791    /// won't be notified that a stream has been opened until the stream is
792    /// actually used.
793    pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> {
794        if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Bi)? {
795            Ok((
796                SendStream::new(self.0.clone(), stream, is_0rtt),
797                RecvStream::new(self.0.clone(), stream, is_0rtt),
798            ))
799        } else {
800            Err(OpenStreamError::StreamsExhausted)
801        }
802    }
803
804    /// Initiate a new outgoing bidirectional stream.
805    ///
806    /// Unlike [`open_bi()`], this method will wait for the connection to allow
807    /// a new stream to be opened.
808    ///
809    /// See [`open_bi()`] for details.
810    ///
811    /// [`open_bi()`]: crate::Connection::open_bi
812    pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
813        let (stream, is_0rtt) =
814            future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?;
815        Ok((
816            SendStream::new(self.0.clone(), stream, is_0rtt),
817            RecvStream::new(self.0.clone(), stream, is_0rtt),
818        ))
819    }
820
821    fn poll_accept_stream(
822        &self,
823        cx: &mut Context,
824        dir: Dir,
825    ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
826        let mut state = self.0.try_state()?;
827        if let Some(stream) = state.conn.streams().accept(dir) {
828            state.wake();
829            Poll::Ready(Ok((stream, state.conn.is_handshaking())))
830        } else {
831            state.stream_opened[dir as usize].push_back(cx.waker().clone());
832            Poll::Pending
833        }
834    }
835
836    /// Accept the next incoming uni-directional stream
837    pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
838        let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?;
839        Ok(RecvStream::new(self.0.clone(), stream, is_0rtt))
840    }
841
842    /// Accept the next incoming bidirectional stream
843    ///
844    /// **Important Note**: The `Connection` that calls [`open_bi()`] must write
845    /// to its [`SendStream`] before the other `Connection` is able to
846    /// `accept_bi()`. Calling [`open_bi()`] then waiting on the [`RecvStream`]
847    /// without writing anything to [`SendStream`] will never succeed.
848    ///
849    /// [`accept_bi()`]: crate::Connection::accept_bi
850    /// [`open_bi()`]: crate::Connection::open_bi
851    /// [`SendStream`]: crate::SendStream
852    /// [`RecvStream`]: crate::RecvStream
853    pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
854        let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?;
855        Ok((
856            SendStream::new(self.0.clone(), stream, is_0rtt),
857            RecvStream::new(self.0.clone(), stream, is_0rtt),
858        ))
859    }
860
861    /// Wait for the connection to be fully established.
862    ///
863    /// For clients, the resulting value indicates if 0-RTT was accepted. For
864    /// servers, the resulting value is meaningless.
865    pub async fn accepted_0rtt(&self) -> Result<bool, ConnectionError> {
866        future::poll_fn(|cx| {
867            let mut state = self.0.try_state()?;
868
869            if state.connected {
870                return Poll::Ready(Ok(state.conn.accepted_0rtt()));
871            }
872
873            match &state.on_connected {
874                Some(waker) if waker.will_wake(cx.waker()) => {}
875                _ => state.on_connected = Some(cx.waker().clone()),
876            }
877
878            Poll::Pending
879        })
880        .await
881    }
882}
883
884impl PartialEq for Connection {
885    fn eq(&self, other: &Self) -> bool {
886        Shared::ptr_eq(&self.0, &other.0)
887    }
888}
889
890impl Eq for Connection {}
891
892impl Drop for Connection {
893    fn drop(&mut self) {
894        implicit_close(&self.0)
895    }
896}
897
898struct Timer {
899    deadline: Option<Instant>,
900    fut: Fuse<LocalBoxFuture<'static, ()>>,
901}
902
903impl Timer {
904    fn new() -> Self {
905        Self {
906            deadline: None,
907            fut: Fuse::terminated(),
908        }
909    }
910
911    fn reset(&mut self, deadline: Option<Instant>) {
912        if let Some(deadline) = deadline {
913            if self.deadline.is_none() || self.deadline != Some(deadline) {
914                self.fut = compio_runtime::time::sleep_until(deadline)
915                    .boxed_local()
916                    .fuse();
917            }
918        } else {
919            self.fut = Fuse::terminated();
920        }
921        self.deadline = deadline;
922    }
923}
924
925impl Future for Timer {
926    type Output = ();
927
928    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
929        self.fut.poll_unpin(cx)
930    }
931}
932
933impl FusedFuture for Timer {
934    fn is_terminated(&self) -> bool {
935        self.fut.is_terminated()
936    }
937}
938
/// Reasons why a connection might be lost.
///
/// This is a variant-for-variant counterpart of
/// `quinn_proto::ConnectionError`. A `From<quinn_proto::ConnectionError>`
/// conversion is provided.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
    /// The peer doesn't implement any supported version.
    #[error("peer doesn't implement any supported version")]
    VersionMismatch,
    /// The peer violated the QUIC specification as understood by this
    /// implementation.
    ///
    /// Displayed transparently as the wrapped [`quinn_proto::TransportError`].
    /// The `#[from]` attribute also lets `?` convert a transport error into
    /// this variant.
    #[error(transparent)]
    TransportError(#[from] quinn_proto::TransportError),
    /// The peer's QUIC stack aborted the connection automatically.
    ///
    /// Carries the [`quinn_proto::ConnectionClose`] reported by the peer.
    #[error("aborted by peer: {0}")]
    ConnectionClosed(quinn_proto::ConnectionClose),
    /// The peer closed the connection.
    ///
    /// Carries the [`quinn_proto::ApplicationClose`], including its
    /// application-chosen `error_code`.
    #[error("closed by peer: {0}")]
    ApplicationClosed(quinn_proto::ApplicationClose),
    /// The peer is unable to continue processing this connection, usually due
    /// to having restarted.
    #[error("reset by peer")]
    Reset,
    /// Communication with the peer has lapsed for longer than the negotiated
    /// idle timeout.
    ///
    /// If neither side is sending keep-alives, a connection will time out after
    /// a long enough idle period even if the peer is still reachable. See
    /// also [`TransportConfig::max_idle_timeout()`](quinn_proto::TransportConfig::max_idle_timeout())
    /// and [`TransportConfig::keep_alive_interval()`](quinn_proto::TransportConfig::keep_alive_interval()).
    #[error("timed out")]
    TimedOut,
    /// The local application closed the connection.
    #[error("closed")]
    LocallyClosed,
    /// The connection could not be created because not enough of the CID space
    /// is available.
    ///
    /// Try using longer connection IDs.
    #[error("CIDs exhausted")]
    CidsExhausted,
}
978
979impl From<quinn_proto::ConnectionError> for ConnectionError {
980    fn from(value: quinn_proto::ConnectionError) -> Self {
981        use quinn_proto::ConnectionError::*;
982
983        match value {
984            VersionMismatch => ConnectionError::VersionMismatch,
985            TransportError(e) => ConnectionError::TransportError(e),
986            ConnectionClosed(e) => ConnectionError::ConnectionClosed(e),
987            ApplicationClosed(e) => ConnectionError::ApplicationClosed(e),
988            Reset => ConnectionError::Reset,
989            TimedOut => ConnectionError::TimedOut,
990            LocallyClosed => ConnectionError::LocallyClosed,
991            CidsExhausted => ConnectionError::CidsExhausted,
992        }
993    }
994}
995
/// Errors that can arise when sending a datagram.
///
/// Within the `h3` feature, these are converted into h3's
/// `SendDatagramErrorIncoming`.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum SendDatagramError {
    /// The peer does not support receiving datagram frames.
    #[error("datagrams not supported by peer")]
    UnsupportedByPeer,
    /// Datagram support is disabled locally.
    #[error("datagram support disabled")]
    Disabled,
    /// The datagram is larger than the connection can currently accommodate.
    ///
    /// Indicates that the path MTU minus overhead or the limit advertised by
    /// the peer has been exceeded.
    #[error("datagram too large")]
    TooLarge,
    /// The connection was lost.
    ///
    /// Wraps the underlying [`ConnectionError`]. The `#[from]` attribute lets
    /// `?` convert a `ConnectionError` into this variant.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
}
1015
/// Errors that can arise when trying to open a stream.
///
/// Returned by [`Connection::open_uni`] and [`Connection::open_bi`]. The
/// waiting variants (`open_uni_wait`, `open_bi_wait`) wait instead of
/// reporting [`OpenStreamError::StreamsExhausted`].
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum OpenStreamError {
    /// The connection was lost.
    ///
    /// The `#[from]` attribute lets `?` convert a [`ConnectionError`] into
    /// this variant.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The streams in the given direction are currently exhausted.
    #[error("streams exhausted")]
    StreamsExhausted,
}
1026
1027#[cfg(feature = "h3")]
1028pub(crate) mod h3_impl {
1029    use std::sync::Arc;
1030
1031    use compio_buf::bytes::Buf;
1032    use futures_util::ready;
1033    use h3::{
1034        error::Code,
1035        quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, WriteBuf},
1036    };
1037    use h3_datagram::{
1038        datagram::EncodedDatagram,
1039        quic_traits::{
1040            DatagramConnectionExt, RecvDatagram, SendDatagram, SendDatagramErrorIncoming,
1041        },
1042    };
1043
1044    use super::*;
1045    use crate::send_stream::h3_impl::SendStream;
1046
1047    impl From<ConnectionError> for ConnectionErrorIncoming {
1048        fn from(e: ConnectionError) -> Self {
1049            use ConnectionError::*;
1050            match e {
1051                ApplicationClosed(e) => Self::ApplicationClose {
1052                    error_code: e.error_code.into_inner(),
1053                },
1054                TimedOut => Self::Timeout,
1055
1056                e => Self::Undefined(Arc::new(e)),
1057            }
1058        }
1059    }
1060
1061    impl From<ConnectionError> for StreamErrorIncoming {
1062        fn from(e: ConnectionError) -> Self {
1063            Self::ConnectionErrorIncoming {
1064                connection_error: e.into(),
1065            }
1066        }
1067    }
1068
1069    impl From<SendDatagramError> for SendDatagramErrorIncoming {
1070        fn from(e: SendDatagramError) -> Self {
1071            use SendDatagramError::*;
1072            match e {
1073                UnsupportedByPeer | Disabled => Self::NotAvailable,
1074                TooLarge => Self::TooLarge,
1075                ConnectionLost(e) => Self::ConnectionError(e.into()),
1076            }
1077        }
1078    }
1079
1080    impl<B> SendDatagram<B> for Connection
1081    where
1082        B: Buf,
1083    {
1084        fn send_datagram<T: Into<EncodedDatagram<B>>>(
1085            &mut self,
1086            data: T,
1087        ) -> Result<(), SendDatagramErrorIncoming> {
1088            let mut buf: EncodedDatagram<B> = data.into();
1089            let buf = buf.copy_to_bytes(buf.remaining());
1090            Ok(Connection::send_datagram(self, buf)?)
1091        }
1092    }
1093
1094    impl RecvDatagram for Connection {
1095        type Buffer = Bytes;
1096
1097        fn poll_incoming_datagram(
1098            &mut self,
1099            cx: &mut core::task::Context<'_>,
1100        ) -> Poll<Result<Self::Buffer, ConnectionErrorIncoming>> {
1101            Poll::Ready(Ok(ready!(self.poll_recv_datagram(cx))?))
1102        }
1103    }
1104
1105    impl<B: Buf> DatagramConnectionExt<B> for Connection {
1106        type RecvDatagramHandler = Self;
1107        type SendDatagramHandler = Self;
1108
1109        fn send_datagram_handler(&self) -> Self::SendDatagramHandler {
1110            self.clone()
1111        }
1112
1113        fn recv_datagram_handler(&self) -> Self::RecvDatagramHandler {
1114            self.clone()
1115        }
1116    }
1117
1118    /// Bidirectional stream.
1119    pub struct BidiStream<B> {
1120        send: SendStream<B>,
1121        recv: RecvStream,
1122    }
1123
1124    impl<B> BidiStream<B> {
1125        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
1126            Self {
1127                send: SendStream::new(conn.clone(), stream, is_0rtt),
1128                recv: RecvStream::new(conn, stream, is_0rtt),
1129            }
1130        }
1131    }
1132
1133    impl<B> quic::BidiStream<B> for BidiStream<B>
1134    where
1135        B: Buf,
1136    {
1137        type RecvStream = RecvStream;
1138        type SendStream = SendStream<B>;
1139
1140        fn split(self) -> (Self::SendStream, Self::RecvStream) {
1141            (self.send, self.recv)
1142        }
1143    }
1144
1145    impl<B> quic::RecvStream for BidiStream<B>
1146    where
1147        B: Buf,
1148    {
1149        type Buf = Bytes;
1150
1151        fn poll_data(
1152            &mut self,
1153            cx: &mut Context<'_>,
1154        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
1155            self.recv.poll_data(cx)
1156        }
1157
1158        fn stop_sending(&mut self, error_code: u64) {
1159            self.recv.stop_sending(error_code)
1160        }
1161
1162        fn recv_id(&self) -> quic::StreamId {
1163            self.recv.recv_id()
1164        }
1165    }
1166
1167    impl<B> quic::SendStream<B> for BidiStream<B>
1168    where
1169        B: Buf,
1170    {
1171        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
1172            self.send.poll_ready(cx)
1173        }
1174
1175        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
1176            self.send.send_data(data)
1177        }
1178
1179        fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
1180            self.send.poll_finish(cx)
1181        }
1182
1183        fn reset(&mut self, reset_code: u64) {
1184            self.send.reset(reset_code)
1185        }
1186
1187        fn send_id(&self) -> quic::StreamId {
1188            self.send.send_id()
1189        }
1190    }
1191
1192    impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
1193    where
1194        B: Buf,
1195    {
1196        fn poll_send<D: Buf>(
1197            &mut self,
1198            cx: &mut Context<'_>,
1199            buf: &mut D,
1200        ) -> Poll<Result<usize, StreamErrorIncoming>> {
1201            self.send.poll_send(cx, buf)
1202        }
1203    }
1204
1205    /// Stream opener.
1206    #[derive(Clone)]
1207    pub struct OpenStreams(Connection);
1208
1209    impl<B> quic::OpenStreams<B> for OpenStreams
1210    where
1211        B: Buf,
1212    {
1213        type BidiStream = BidiStream<B>;
1214        type SendStream = SendStream<B>;
1215
1216        fn poll_open_bidi(
1217            &mut self,
1218            cx: &mut Context<'_>,
1219        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
1220            let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Bi))?;
1221            Poll::Ready(Ok(BidiStream::new(self.0.0.clone(), stream, is_0rtt)))
1222        }
1223
1224        fn poll_open_send(
1225            &mut self,
1226            cx: &mut Context<'_>,
1227        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
1228            let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Uni))?;
1229            Poll::Ready(Ok(SendStream::new(self.0.0.clone(), stream, is_0rtt)))
1230        }
1231
1232        fn close(&mut self, code: Code, reason: &[u8]) {
1233            self.0
1234                .close(code.value().try_into().expect("invalid code"), reason)
1235        }
1236    }
1237
1238    impl<B> quic::OpenStreams<B> for Connection
1239    where
1240        B: Buf,
1241    {
1242        type BidiStream = BidiStream<B>;
1243        type SendStream = SendStream<B>;
1244
1245        fn poll_open_bidi(
1246            &mut self,
1247            cx: &mut Context<'_>,
1248        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
1249            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?;
1250            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
1251        }
1252
1253        fn poll_open_send(
1254            &mut self,
1255            cx: &mut Context<'_>,
1256        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
1257            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?;
1258            Poll::Ready(Ok(SendStream::new(self.0.clone(), stream, is_0rtt)))
1259        }
1260
1261        fn close(&mut self, code: Code, reason: &[u8]) {
1262            Connection::close(self, code.value().try_into().expect("invalid code"), reason)
1263        }
1264    }
1265
1266    impl<B> quic::Connection<B> for Connection
1267    where
1268        B: Buf,
1269    {
1270        type OpenStreams = OpenStreams;
1271        type RecvStream = RecvStream;
1272
1273        fn poll_accept_recv(
1274            &mut self,
1275            cx: &mut std::task::Context<'_>,
1276        ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
1277            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Uni))?;
1278            Poll::Ready(Ok(RecvStream::new(self.0.clone(), stream, is_0rtt)))
1279        }
1280
1281        fn poll_accept_bidi(
1282            &mut self,
1283            cx: &mut std::task::Context<'_>,
1284        ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
1285            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Bi))?;
1286            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
1287        }
1288
1289        fn opener(&self) -> Self::OpenStreams {
1290            OpenStreams(self.clone())
1291        }
1292    }
1293}