// compio_quic/connection.rs

1use std::{
2    collections::VecDeque,
3    fmt::Debug,
4    net::{IpAddr, SocketAddr},
5    pin::{Pin, pin},
6    task::{Context, Poll, Waker},
7    time::{Duration, Instant},
8};
9
10use compio_buf::bytes::Bytes;
11use compio_log::Instrument;
12use compio_runtime::JoinHandle;
13use flume::{Receiver, Sender};
14use futures_util::{
15    FutureExt, StreamExt,
16    future::{self, Fuse, FusedFuture, LocalBoxFuture},
17    select, stream,
18};
19#[cfg(rustls)]
20use quinn_proto::crypto::rustls::HandshakeData;
21use quinn_proto::{
22    ConnectionHandle, ConnectionStats, Dir, EndpointEvent, Side, StreamEvent, StreamId, VarInt,
23    congestion::Controller,
24};
25use rustc_hash::FxHashMap as HashMap;
26use thiserror::Error;
27
28use crate::{
29    RecvStream, SendStream, Socket,
30    sync::{
31        mutex_blocking::{Mutex, MutexGuard},
32        shared::Shared,
33    },
34};
35
36#[derive(Debug)]
37pub(crate) enum ConnectionEvent {
38    Close(VarInt, Bytes),
39    Proto(quinn_proto::ConnectionEvent),
40}
41
42#[derive(Debug)]
43pub(crate) struct ConnectionState {
44    pub(crate) conn: quinn_proto::Connection,
45    pub(crate) error: Option<ConnectionError>,
46    connected: bool,
47    worker: Option<JoinHandle<()>>,
48    poller: Option<Waker>,
49    on_connected: Option<Waker>,
50    on_handshake_data: Option<Waker>,
51    datagram_received: VecDeque<Waker>,
52    datagrams_unblocked: VecDeque<Waker>,
53    stream_opened: [VecDeque<Waker>; 2],
54    stream_available: [VecDeque<Waker>; 2],
55    pub(crate) writable: HashMap<StreamId, Waker>,
56    pub(crate) readable: HashMap<StreamId, Waker>,
57    pub(crate) stopped: HashMap<StreamId, Waker>,
58}
59
60impl ConnectionState {
61    fn terminate(&mut self, reason: ConnectionError) {
62        self.error = Some(reason);
63        self.connected = false;
64
65        if let Some(waker) = self.on_handshake_data.take() {
66            waker.wake()
67        }
68        if let Some(waker) = self.on_connected.take() {
69            waker.wake()
70        }
71        self.datagram_received.drain(..).for_each(Waker::wake);
72        self.datagrams_unblocked.drain(..).for_each(Waker::wake);
73        for e in &mut self.stream_opened {
74            e.drain(..).for_each(Waker::wake);
75        }
76        for e in &mut self.stream_available {
77            e.drain(..).for_each(Waker::wake);
78        }
79        wake_all_streams(&mut self.writable);
80        wake_all_streams(&mut self.readable);
81        wake_all_streams(&mut self.stopped);
82    }
83
84    fn close(&mut self, error_code: VarInt, reason: Bytes) {
85        self.conn.close(Instant::now(), error_code, reason);
86        self.terminate(ConnectionError::LocallyClosed);
87        self.wake();
88    }
89
90    pub(crate) fn wake(&mut self) {
91        if let Some(waker) = self.poller.take() {
92            waker.wake()
93        }
94    }
95
96    #[cfg(rustls)]
97    fn handshake_data(&self) -> Option<Box<HandshakeData>> {
98        self.conn
99            .crypto_session()
100            .handshake_data()
101            .map(|data| data.downcast::<HandshakeData>().unwrap())
102    }
103
104    pub(crate) fn check_0rtt(&self) -> bool {
105        self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt()
106    }
107}
108
109fn wake_stream(stream: StreamId, wakers: &mut HashMap<StreamId, Waker>) {
110    if let Some(waker) = wakers.remove(&stream) {
111        waker.wake();
112    }
113}
114
115fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
116    wakers.drain().for_each(|(_, waker)| waker.wake())
117}
118
119#[derive(Debug)]
120pub(crate) struct ConnectionInner {
121    state: Mutex<ConnectionState>,
122    handle: ConnectionHandle,
123    socket: Socket,
124    events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
125    events_rx: Receiver<ConnectionEvent>,
126}
127
128fn implicit_close(this: &Shared<ConnectionInner>) {
129    if Shared::strong_count(this) == 2 {
130        this.state().close(0u32.into(), Bytes::new())
131    }
132}
133
134impl ConnectionInner {
135    fn new(
136        handle: ConnectionHandle,
137        conn: quinn_proto::Connection,
138        socket: Socket,
139        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
140        events_rx: Receiver<ConnectionEvent>,
141    ) -> Self {
142        Self {
143            state: Mutex::new(ConnectionState {
144                conn,
145                connected: false,
146                error: None,
147                worker: None,
148                poller: None,
149                on_connected: None,
150                on_handshake_data: None,
151                datagram_received: VecDeque::new(),
152                datagrams_unblocked: VecDeque::new(),
153                stream_opened: [VecDeque::new(), VecDeque::new()],
154                stream_available: [VecDeque::new(), VecDeque::new()],
155                writable: HashMap::default(),
156                readable: HashMap::default(),
157                stopped: HashMap::default(),
158            }),
159            handle,
160            socket,
161            events_tx,
162            events_rx,
163        }
164    }
165
166    #[inline]
167    pub(crate) fn state(&self) -> MutexGuard<'_, ConnectionState> {
168        self.state.lock()
169    }
170
171    #[inline]
172    pub(crate) fn try_state(&self) -> Result<MutexGuard<'_, ConnectionState>, ConnectionError> {
173        let state = self.state();
174        if let Some(error) = &state.error {
175            Err(error.clone())
176        } else {
177            Ok(state)
178        }
179    }
180
181    async fn run(&self) {
182        let mut poller = stream::poll_fn(|cx| {
183            let mut state = self.state();
184            let ready = state.poller.is_none();
185            match &state.poller {
186                Some(waker) if waker.will_wake(cx.waker()) => {}
187                _ => state.poller = Some(cx.waker().clone()),
188            };
189            if ready {
190                Poll::Ready(Some(()))
191            } else {
192                Poll::Pending
193            }
194        })
195        .fuse();
196
197        let mut timer = Timer::new();
198        let mut event_stream = self.events_rx.stream().ready_chunks(100);
199        let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
200        let mut transmit_fut = pin!(Fuse::terminated());
201
202        loop {
203            let mut state = select! {
204                _ = poller.select_next_some() => self.state(),
205                _ = timer => {
206                    timer.reset(None);
207                    let mut state = self.state();
208                    state.conn.handle_timeout(Instant::now());
209                    state
210                }
211                events = event_stream.select_next_some() => {
212                    let mut state = self.state();
213                    for event in events {
214                        match event {
215                            ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason),
216                            ConnectionEvent::Proto(event) => state.conn.handle_event(event),
217                        }
218                    }
219                    state
220                },
221                buf = transmit_fut => {
222                    // The following line is required to avoid "type annotations needed" error
223                    let mut buf: Vec<_> = buf;
224                    buf.clear();
225                    send_buf = Some(buf);
226                    self.state()
227                },
228            };
229
230            if let Some(mut buf) = send_buf.take() {
231                if let Some(transmit) = state.conn.poll_transmit(
232                    Instant::now(),
233                    self.socket.max_gso_segments(),
234                    &mut buf,
235                ) {
236                    transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse())
237                } else {
238                    send_buf = Some(buf);
239                }
240            }
241
242            timer.reset(state.conn.poll_timeout());
243
244            while let Some(event) = state.conn.poll_endpoint_events() {
245                let _ = self.events_tx.send((self.handle, event));
246            }
247
248            while let Some(event) = state.conn.poll() {
249                use quinn_proto::Event::*;
250                match event {
251                    HandshakeDataReady => {
252                        if let Some(waker) = state.on_handshake_data.take() {
253                            waker.wake()
254                        }
255                    }
256                    Connected => {
257                        state.connected = true;
258                        if let Some(waker) = state.on_connected.take() {
259                            waker.wake()
260                        }
261                        if state.conn.side().is_client() && !state.conn.accepted_0rtt() {
262                            // Wake up rejected 0-RTT streams so they can fail immediately with
263                            // `ZeroRttRejected` errors.
264                            wake_all_streams(&mut state.writable);
265                            wake_all_streams(&mut state.readable);
266                            wake_all_streams(&mut state.stopped);
267                        }
268                    }
269                    ConnectionLost { reason } => state.terminate(reason.into()),
270                    Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable),
271                    Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable),
272                    Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped),
273                    Stream(StreamEvent::Stopped { id, .. }) => {
274                        wake_stream(id, &mut state.stopped);
275                        wake_stream(id, &mut state.writable);
276                    }
277                    Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize]
278                        .drain(..)
279                        .for_each(Waker::wake),
280                    Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize]
281                        .drain(..)
282                        .for_each(Waker::wake),
283                    DatagramReceived => state.datagram_received.drain(..).for_each(Waker::wake),
284                    DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake),
285                }
286            }
287
288            if state.conn.is_drained() {
289                break;
290            }
291        }
292    }
293}
294
// Expands to the read-only accessors shared by `Connecting` and `Connection`.
// Both are newtypes over `Shared<ConnectionInner>`, and every accessor briefly
// locks the connection state through `self.0.state()`.
macro_rules! conn_fn {
    () => {
        /// The side of the connection (client or server)
        pub fn side(&self) -> Side {
            self.0.state().conn.side()
        }

        /// The local IP address which was used when the peer established
        /// the connection.
        ///
        /// This can be different from the address the endpoint is bound to, in case
        /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`.
        ///
        /// This will return `None` for clients, or when the platform does not
        /// expose this information.
        pub fn local_ip(&self) -> Option<IpAddr> {
            self.0.state().conn.local_ip()
        }

        /// The peer's UDP address.
        pub fn remote_address(&self) -> SocketAddr {
            self.0.state().conn.remote_address()
        }

        /// Current best estimate of this connection's latency (round-trip-time).
        pub fn rtt(&self) -> Duration {
            self.0.state().conn.rtt()
        }

        /// Connection statistics.
        pub fn stats(&self) -> ConnectionStats {
            self.0.state().conn.stats()
        }

        /// Current state of the congestion control algorithm. (For debugging
        /// purposes)
        pub fn congestion_state(&self) -> Box<dyn Controller> {
            self.0.state().conn.congestion_state().clone_box()
        }

        /// Cryptographic identity of the peer.
        pub fn peer_identity(
            &self,
        ) -> Option<Box<Vec<rustls::pki_types::CertificateDer<'static>>>> {
            self.0
                .state()
                .conn
                .crypto_session()
                .peer_identity()
                .map(|v| v.downcast().unwrap())
        }

        /// A stable identifier for this connection
        ///
        /// Peer addresses and connection IDs can change, but this value will remain
        /// fixed for the lifetime of the connection.
        pub fn stable_id(&self) -> usize {
            Shared::as_ptr(&self.0) as usize
        }

        /// Derive keying material from this connection's TLS session secrets.
        ///
        /// When both peers call this method with the same `label` and `context`
        /// arguments and `output` buffers of equal length, they will get the
        /// same sequence of bytes in `output`. These bytes are cryptographically
        /// strong and pseudorandom, and are suitable for use as keying material.
        ///
        /// This function fails if called with an empty `output` or called prior to
        /// the handshake completing.
        ///
        /// See [RFC5705](https://tools.ietf.org/html/rfc5705) for more information.
        pub fn export_keying_material(
            &self,
            output: &mut [u8],
            label: &[u8],
            context: &[u8],
        ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> {
            self.0
                .state()
                .conn
                .crypto_session()
                .export_keying_material(output, label, context)
        }
    };
}
382
383/// In-progress connection attempt future
384#[derive(Debug)]
385#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
386pub struct Connecting(Shared<ConnectionInner>);
387
388impl Connecting {
389    conn_fn!();
390
391    pub(crate) fn new(
392        handle: ConnectionHandle,
393        conn: quinn_proto::Connection,
394        socket: Socket,
395        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
396        events_rx: Receiver<ConnectionEvent>,
397    ) -> Self {
398        let inner = Shared::new(ConnectionInner::new(
399            handle, conn, socket, events_tx, events_rx,
400        ));
401        let worker = compio_runtime::spawn({
402            let inner = inner.clone();
403            async move { inner.run().await }.in_current_span()
404        });
405        inner.state().worker = Some(worker);
406        Self(inner)
407    }
408
409    /// Parameters negotiated during the handshake.
410    #[cfg(rustls)]
411    pub async fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
412        future::poll_fn(|cx| {
413            let mut state = self.0.try_state()?;
414            if let Some(data) = state.handshake_data() {
415                return Poll::Ready(Ok(data));
416            }
417
418            match &state.on_handshake_data {
419                Some(waker) if waker.will_wake(cx.waker()) => {}
420                _ => state.on_handshake_data = Some(cx.waker().clone()),
421            }
422
423            Poll::Pending
424        })
425        .await
426    }
427
428    /// Convert into a 0-RTT or 0.5-RTT connection at the cost of weakened
429    /// security.
430    ///
431    /// Returns `Ok` immediately if the local endpoint is able to attempt
432    /// sending 0/0.5-RTT data. If so, the returned [`Connection`] can be used
433    /// to send application data without waiting for the rest of the handshake
434    /// to complete, at the cost of weakened cryptographic security guarantees.
435    /// The [`Connection::accepted_0rtt`] method resolves when the handshake
436    /// does complete, at which point subsequently opened streams and written
437    /// data will have full cryptographic protection.
438    ///
439    /// ## Outgoing
440    ///
441    /// For outgoing connections, the initial attempt to convert to a
442    /// [`Connection`] which sends 0-RTT data will proceed if the
443    /// [`crypto::ClientConfig`][crate::crypto::ClientConfig] attempts to resume
444    /// a previous TLS session. However, **the remote endpoint may not actually
445    /// _accept_ the 0-RTT data**--yet still accept the connection attempt in
446    /// general. This possibility is conveyed through the
447    /// [`Connection::accepted_0rtt`] method--when the handshake completes, it
448    /// resolves to true if the 0-RTT data was accepted and false if it was
449    /// rejected. If it was rejected, the existence of streams opened and other
450    /// application data sent prior to the handshake completing will not be
451    /// conveyed to the remote application, and local operations on them will
452    /// return `ZeroRttRejected` errors.
453    ///
454    /// A server may reject 0-RTT data at its discretion, but accepting 0-RTT
455    /// data requires the relevant resumption state to be stored in the server,
456    /// which servers may limit or lose for various reasons including not
457    /// persisting resumption state across server restarts.
458    ///
459    /// ## Incoming
460    ///
461    /// For incoming connections, conversion to 0.5-RTT will always fully
462    /// succeed. `into_0rtt` will always return `Ok` and
463    /// [`Connection::accepted_0rtt`] will always resolve to true.
464    ///
465    /// ## Security
466    ///
467    /// On outgoing connections, this enables transmission of 0-RTT data, which
468    /// is vulnerable to replay attacks, and should therefore never invoke
469    /// non-idempotent operations.
470    ///
471    /// On incoming connections, this enables transmission of 0.5-RTT data,
472    /// which may be sent before TLS client authentication has occurred, and
473    /// should therefore not be used to send data for which client
474    /// authentication is being used.
475    pub fn into_0rtt(self) -> Result<Connection, Self> {
476        let is_ok = {
477            let state = self.0.state();
478            state.conn.has_0rtt() || state.conn.side().is_server()
479        };
480        if is_ok {
481            Ok(Connection(self.0.clone()))
482        } else {
483            Err(self)
484        }
485    }
486}
487
488impl Future for Connecting {
489    type Output = Result<Connection, ConnectionError>;
490
491    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
492        let mut state = self.0.try_state()?;
493
494        if state.connected {
495            return Poll::Ready(Ok(Connection(self.0.clone())));
496        }
497
498        match &state.on_connected {
499            Some(waker) if waker.will_wake(cx.waker()) => {}
500            _ => state.on_connected = Some(cx.waker().clone()),
501        }
502
503        Poll::Pending
504    }
505}
506
507impl Drop for Connecting {
508    fn drop(&mut self) {
509        implicit_close(&self.0)
510    }
511}
512
513/// A QUIC connection.
514#[derive(Debug, Clone)]
515pub struct Connection(Shared<ConnectionInner>);
516
517impl Connection {
518    conn_fn!();
519
520    /// Update traffic keys spontaneously
521    ///
522    /// This primarily exists for testing purposes.
523    pub fn force_key_update(&self) {
524        self.0.state().conn.force_key_update()
525    }
526
527    /// Parameters negotiated during the handshake.
528    #[cfg(rustls)]
529    pub fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
530        Ok(self.0.try_state()?.handshake_data().unwrap())
531    }
532
533    /// Compute the maximum size of datagrams that may be passed to
534    /// [`send_datagram()`](Self::send_datagram).
535    ///
536    /// Returns `None` if datagrams are unsupported by the peer or disabled
537    /// locally.
538    ///
539    /// This may change over the lifetime of a connection according to variation
540    /// in the path MTU estimate. The peer can also enforce an arbitrarily small
541    /// fixed limit, but if the peer's limit is large this is guaranteed to be a
542    /// little over a kilobyte at minimum.
543    ///
544    /// Not necessarily the maximum size of received datagrams.
545    pub fn max_datagram_size(&self) -> Option<usize> {
546        self.0.state().conn.datagrams().max_size()
547    }
548
549    /// Bytes available in the outgoing datagram buffer.
550    ///
551    /// When greater than zero, calling [`send_datagram()`](Self::send_datagram)
552    /// with a datagram of at most this size is guaranteed not to cause older
553    /// datagrams to be dropped.
554    pub fn datagram_send_buffer_space(&self) -> usize {
555        self.0.state().conn.datagrams().send_buffer_space()
556    }
557
558    /// Modify the number of remotely initiated unidirectional streams that may
559    /// be concurrently open.
560    ///
561    /// No streams may be opened by the peer unless fewer than `count` are
562    /// already open. Large `count`s increase both minimum and worst-case
563    /// memory consumption.
564    pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
565        let mut state = self.0.state();
566        state.conn.set_max_concurrent_streams(Dir::Uni, count);
567        // May need to send MAX_STREAMS to make progress
568        state.wake();
569    }
570
571    /// See [`quinn_proto::TransportConfig::send_window()`]
572    pub fn set_send_window(&self, send_window: u64) {
573        let mut state = self.0.state();
574        state.conn.set_send_window(send_window);
575        state.wake();
576    }
577
578    /// See [`quinn_proto::TransportConfig::receive_window()`]
579    pub fn set_receive_window(&self, receive_window: VarInt) {
580        let mut state = self.0.state();
581        state.conn.set_receive_window(receive_window);
582        state.wake();
583    }
584
585    /// Modify the number of remotely initiated bidirectional streams that may
586    /// be concurrently open.
587    ///
588    /// No streams may be opened by the peer unless fewer than `count` are
589    /// already open. Large `count`s increase both minimum and worst-case
590    /// memory consumption.
591    pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
592        let mut state = self.0.state();
593        state.conn.set_max_concurrent_streams(Dir::Bi, count);
594        // May need to send MAX_STREAMS to make progress
595        state.wake();
596    }
597
598    /// Close the connection immediately.
599    ///
600    /// Pending operations will fail immediately with
601    /// [`ConnectionError::LocallyClosed`]. No more data is sent to the peer
602    /// and the peer may drop buffered data upon receiving
603    /// the CONNECTION_CLOSE frame.
604    ///
605    /// `error_code` and `reason` are not interpreted, and are provided directly
606    /// to the peer.
607    ///
608    /// `reason` will be truncated to fit in a single packet with overhead; to
609    /// improve odds that it is preserved in full, it should be kept under
610    /// 1KiB.
611    ///
612    /// # Gracefully closing a connection
613    ///
614    /// Only the peer last receiving application data can be certain that all
615    /// data is delivered. The only reliable action it can then take is to
616    /// close the connection, potentially with a custom error code. The
617    /// delivery of the final CONNECTION_CLOSE frame is very likely if both
618    /// endpoints stay online long enough, and [`Endpoint::shutdown()`] can
619    /// be used to provide sufficient time. Otherwise, the remote peer will
620    /// time out the connection, provided that the idle timeout is not
621    /// disabled.
622    ///
623    /// The sending side can not guarantee all stream data is delivered to the
624    /// remote application. It only knows the data is delivered to the QUIC
625    /// stack of the remote endpoint. Once the local side sends a
626    /// CONNECTION_CLOSE frame in response to calling [`close()`] the remote
627    /// endpoint may drop any data it received but is as yet undelivered to
628    /// the application, including data that was acknowledged as received to
629    /// the local endpoint.
630    ///
631    /// [`ConnectionError::LocallyClosed`]: ConnectionError::LocallyClosed
632    /// [`Endpoint::shutdown()`]: crate::Endpoint::shutdown
633    /// [`close()`]: Connection::close
634    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
635        self.0
636            .state()
637            .close(error_code, Bytes::copy_from_slice(reason));
638    }
639
640    /// Wait for the connection to be closed for any reason.
641    pub async fn closed(&self) -> ConnectionError {
642        let worker = self.0.state().worker.take();
643        if let Some(worker) = worker {
644            let _ = worker.await;
645        }
646
647        self.0.try_state().unwrap_err()
648    }
649
650    /// If the connection is closed, the reason why.
651    ///
652    /// Returns `None` if the connection is still open.
653    pub fn close_reason(&self) -> Option<ConnectionError> {
654        self.0.try_state().err()
655    }
656
657    fn poll_recv_datagram(&self, cx: &mut Context) -> Poll<Result<Bytes, ConnectionError>> {
658        let mut state = self.0.try_state()?;
659        if let Some(bytes) = state.conn.datagrams().recv() {
660            return Poll::Ready(Ok(bytes));
661        }
662        state.datagram_received.push_back(cx.waker().clone());
663        Poll::Pending
664    }
665
666    /// Try to receive an application datagram. Returns None if no datagram is
667    /// available.
668    pub fn try_recv_datagram(&self) -> Result<Option<Bytes>, ConnectionError> {
669        let mut state = self.0.try_state()?;
670        Ok(state.conn.datagrams().recv())
671    }
672
673    /// Receive an application datagram.
674    pub async fn recv_datagram(&self) -> Result<Bytes, ConnectionError> {
675        future::poll_fn(|cx| self.poll_recv_datagram(cx)).await
676    }
677
678    fn try_send_datagram(
679        &self,
680        cx: Option<&mut Context>,
681        data: Bytes,
682    ) -> Result<(), Result<SendDatagramError, Bytes>> {
683        use quinn_proto::SendDatagramError::*;
684        let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?;
685        state
686            .conn
687            .datagrams()
688            .send(data, cx.is_none())
689            .map_err(|err| match err {
690                UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer),
691                Disabled => Ok(SendDatagramError::Disabled),
692                TooLarge => Ok(SendDatagramError::TooLarge),
693                Blocked(data) => {
694                    state
695                        .datagrams_unblocked
696                        .push_back(cx.unwrap().waker().clone());
697                    Err(data)
698                }
699            })?;
700        state.wake();
701        Ok(())
702    }
703
704    /// Transmit `data` as an unreliable, unordered application datagram.
705    ///
706    /// Application datagrams are a low-level primitive. They may be lost or
707    /// delivered out of order, and `data` must both fit inside a single
708    /// QUIC packet and be smaller than the maximum dictated by the peer.
709    pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
710        self.try_send_datagram(None, data).map_err(Result::unwrap)
711    }
712
713    /// Transmit `data` as an unreliable, unordered application datagram.
714    ///
715    /// Unlike [`send_datagram()`], this method will wait for buffer space
716    /// during congestion conditions, which effectively prioritizes old
717    /// datagrams over new datagrams.
718    ///
719    /// See [`send_datagram()`] for details.
720    ///
721    /// [`send_datagram()`]: Connection::send_datagram
722    pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> {
723        let mut data = Some(data);
724        future::poll_fn(
725            |cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) {
726                Ok(()) => Poll::Ready(Ok(())),
727                Err(Ok(e)) => Poll::Ready(Err(e)),
728                Err(Err(b)) => {
729                    data.replace(b);
730                    Poll::Pending
731                }
732            },
733        )
734        .await
735    }
736
737    fn poll_open_stream(
738        &self,
739        cx: Option<&mut Context>,
740        dir: Dir,
741    ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
742        let mut state = self.0.try_state()?;
743        if let Some(stream) = state.conn.streams().open(dir) {
744            Poll::Ready(Ok((
745                stream,
746                state.conn.side().is_client() && state.conn.is_handshaking(),
747            )))
748        } else {
749            if let Some(cx) = cx {
750                state.stream_available[dir as usize].push_back(cx.waker().clone());
751            }
752            Poll::Pending
753        }
754    }
755
756    /// Initiate a new outgoing unidirectional stream.
757    ///
758    /// Streams are cheap and instantaneous to open. As a consequence, the peer
759    /// won't be notified that a stream has been opened until the stream is
760    /// actually used.
761    pub fn open_uni(&self) -> Result<SendStream, OpenStreamError> {
762        if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Uni)? {
763            Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
764        } else {
765            Err(OpenStreamError::StreamsExhausted)
766        }
767    }
768
769    /// Initiate a new outgoing unidirectional stream.
770    ///
771    /// Unlike [`open_uni()`], this method will wait for the connection to allow
772    /// a new stream to be opened.
773    ///
774    /// See [`open_uni()`] for details.
775    ///
776    /// [`open_uni()`]: crate::Connection::open_uni
777    pub async fn open_uni_wait(&self) -> Result<SendStream, ConnectionError> {
778        let (stream, is_0rtt) =
779            future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?;
780        Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
781    }
782
783    /// Initiate a new outgoing bidirectional stream.
784    ///
785    /// Streams are cheap and instantaneous to open. As a consequence, the peer
786    /// won't be notified that a stream has been opened until the stream is
787    /// actually used.
788    pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> {
789        if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Bi)? {
790            Ok((
791                SendStream::new(self.0.clone(), stream, is_0rtt),
792                RecvStream::new(self.0.clone(), stream, is_0rtt),
793            ))
794        } else {
795            Err(OpenStreamError::StreamsExhausted)
796        }
797    }
798
799    /// Initiate a new outgoing bidirectional stream.
800    ///
801    /// Unlike [`open_bi()`], this method will wait for the connection to allow
802    /// a new stream to be opened.
803    ///
804    /// See [`open_bi()`] for details.
805    ///
806    /// [`open_bi()`]: crate::Connection::open_bi
807    pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
808        let (stream, is_0rtt) =
809            future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?;
810        Ok((
811            SendStream::new(self.0.clone(), stream, is_0rtt),
812            RecvStream::new(self.0.clone(), stream, is_0rtt),
813        ))
814    }
815
816    fn poll_accept_stream(
817        &self,
818        cx: &mut Context,
819        dir: Dir,
820    ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
821        let mut state = self.0.try_state()?;
822        if let Some(stream) = state.conn.streams().accept(dir) {
823            state.wake();
824            Poll::Ready(Ok((stream, state.conn.is_handshaking())))
825        } else {
826            state.stream_opened[dir as usize].push_back(cx.waker().clone());
827            Poll::Pending
828        }
829    }
830
831    /// Accept the next incoming uni-directional stream
832    pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
833        let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?;
834        Ok(RecvStream::new(self.0.clone(), stream, is_0rtt))
835    }
836
837    /// Accept the next incoming bidirectional stream
838    ///
839    /// **Important Note**: The `Connection` that calls [`open_bi()`] must write
840    /// to its [`SendStream`] before the other `Connection` is able to
841    /// `accept_bi()`. Calling [`open_bi()`] then waiting on the [`RecvStream`]
842    /// without writing anything to [`SendStream`] will never succeed.
843    ///
844    /// [`accept_bi()`]: crate::Connection::accept_bi
845    /// [`open_bi()`]: crate::Connection::open_bi
846    /// [`SendStream`]: crate::SendStream
847    /// [`RecvStream`]: crate::RecvStream
848    pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
849        let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?;
850        Ok((
851            SendStream::new(self.0.clone(), stream, is_0rtt),
852            RecvStream::new(self.0.clone(), stream, is_0rtt),
853        ))
854    }
855
856    /// Wait for the connection to be fully established.
857    ///
858    /// For clients, the resulting value indicates if 0-RTT was accepted. For
859    /// servers, the resulting value is meaningless.
860    pub async fn accepted_0rtt(&self) -> Result<bool, ConnectionError> {
861        future::poll_fn(|cx| {
862            let mut state = self.0.try_state()?;
863
864            if state.connected {
865                return Poll::Ready(Ok(state.conn.accepted_0rtt()));
866            }
867
868            match &state.on_connected {
869                Some(waker) if waker.will_wake(cx.waker()) => {}
870                _ => state.on_connected = Some(cx.waker().clone()),
871            }
872
873            Poll::Pending
874        })
875        .await
876    }
877}
878
879impl PartialEq for Connection {
880    fn eq(&self, other: &Self) -> bool {
881        Shared::ptr_eq(&self.0, &other.0)
882    }
883}
884
885impl Eq for Connection {}
886
887impl Drop for Connection {
888    fn drop(&mut self) {
889        implicit_close(&self.0)
890    }
891}
892
893struct Timer {
894    deadline: Option<Instant>,
895    fut: Fuse<LocalBoxFuture<'static, ()>>,
896}
897
898impl Timer {
899    fn new() -> Self {
900        Self {
901            deadline: None,
902            fut: Fuse::terminated(),
903        }
904    }
905
906    fn reset(&mut self, deadline: Option<Instant>) {
907        if let Some(deadline) = deadline {
908            if self.deadline.is_none() || self.deadline != Some(deadline) {
909                self.fut = compio_runtime::time::sleep_until(deadline)
910                    .boxed_local()
911                    .fuse();
912            }
913        } else {
914            self.fut = Fuse::terminated();
915        }
916        self.deadline = deadline;
917    }
918}
919
920impl Future for Timer {
921    type Output = ();
922
923    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
924        self.fut.poll_unpin(cx)
925    }
926}
927
928impl FusedFuture for Timer {
929    fn is_terminated(&self) -> bool {
930        self.fut.is_terminated()
931    }
932}
933
/// Reasons why a connection might be lost.
///
/// Each variant corresponds to the identically named variant of
/// [`quinn_proto::ConnectionError`], which converts into this type via `From`.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
    /// The peer doesn't implement any supported version.
    #[error("peer doesn't implement any supported version")]
    VersionMismatch,
    /// The peer violated the QUIC specification as understood by this
    /// implementation.
    ///
    /// Displays as the wrapped transport error. It can be built from a
    /// [`quinn_proto::TransportError`] through `From`/`?`.
    #[error(transparent)]
    TransportError(#[from] quinn_proto::TransportError),
    /// The peer's QUIC stack aborted the connection automatically.
    #[error("aborted by peer: {0}")]
    ConnectionClosed(quinn_proto::ConnectionClose),
    /// The peer closed the connection at the application level.
    #[error("closed by peer: {0}")]
    ApplicationClosed(quinn_proto::ApplicationClose),
    /// The peer is unable to continue processing this connection, usually due
    /// to having restarted.
    #[error("reset by peer")]
    Reset,
    /// Communication with the peer has lapsed for longer than the negotiated
    /// idle timeout.
    ///
    /// If neither side is sending keep-alives, a connection will time out after
    /// a long enough idle period even if the peer is still reachable. See
    /// also [`TransportConfig::max_idle_timeout()`](quinn_proto::TransportConfig::max_idle_timeout())
    /// and [`TransportConfig::keep_alive_interval()`](quinn_proto::TransportConfig::keep_alive_interval()).
    #[error("timed out")]
    TimedOut,
    /// The local application closed the connection.
    #[error("closed")]
    LocallyClosed,
    /// The connection could not be created because not enough of the CID space
    /// is available.
    ///
    /// Try using longer connection IDs.
    #[error("CIDs exhausted")]
    CidsExhausted,
}
973
974impl From<quinn_proto::ConnectionError> for ConnectionError {
975    fn from(value: quinn_proto::ConnectionError) -> Self {
976        use quinn_proto::ConnectionError::*;
977
978        match value {
979            VersionMismatch => ConnectionError::VersionMismatch,
980            TransportError(e) => ConnectionError::TransportError(e),
981            ConnectionClosed(e) => ConnectionError::ConnectionClosed(e),
982            ApplicationClosed(e) => ConnectionError::ApplicationClosed(e),
983            Reset => ConnectionError::Reset,
984            TimedOut => ConnectionError::TimedOut,
985            LocallyClosed => ConnectionError::LocallyClosed,
986            CidsExhausted => ConnectionError::CidsExhausted,
987        }
988    }
989}
990
/// Errors that can arise when sending a datagram.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum SendDatagramError {
    /// The peer does not support receiving datagram frames.
    #[error("datagrams not supported by peer")]
    UnsupportedByPeer,
    /// Datagram support is disabled locally.
    #[error("datagram support disabled")]
    Disabled,
    /// The datagram is larger than the connection can currently accommodate.
    ///
    /// Indicates that the path MTU minus overhead or the limit advertised by
    /// the peer has been exceeded.
    #[error("datagram too large")]
    TooLarge,
    /// The connection was lost.
    ///
    /// The underlying [`ConnectionError`] is exposed as the error source and is
    /// converted in by `?`. It is not included in the `Display` message.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
}
1010
/// Errors that can arise when trying to open a stream.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum OpenStreamError {
    /// The connection was lost.
    ///
    /// The underlying [`ConnectionError`] is exposed as the error source and is
    /// converted in by `?`.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The streams in the given direction are currently exhausted.
    ///
    /// Returned by [`Connection::open_uni`] and [`Connection::open_bi`] when no
    /// new stream can be opened right now. The `open_uni_wait` /
    /// `open_bi_wait` variants wait instead of returning this.
    #[error("streams exhausted")]
    StreamsExhausted,
}
1021
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    //! Glue between this crate's QUIC types and the transport traits expected
    //! by `h3` and `h3-datagram`.

    use std::sync::Arc;

    use compio_buf::bytes::Buf;
    use futures_util::ready;
    use h3::{
        error::Code,
        quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, WriteBuf},
    };
    use h3_datagram::{
        datagram::EncodedDatagram,
        quic_traits::{
            DatagramConnectionExt, RecvDatagram, SendDatagram, SendDatagramErrorIncoming,
        },
    };

    use super::*;
    use crate::send_stream::h3_impl::SendStream;

    impl From<ConnectionError> for ConnectionErrorIncoming {
        /// Translate a connection error into `h3`'s representation.
        ///
        /// Application closes and idle timeouts have dedicated `h3` variants.
        /// Every other reason is passed through opaquely as `Undefined`.
        fn from(e: ConnectionError) -> Self {
            use ConnectionError::*;
            match e {
                ApplicationClosed(e) => Self::ApplicationClose {
                    error_code: e.error_code.into_inner(),
                },
                TimedOut => Self::Timeout,
                // Listed explicitly rather than with a catch-all, so adding a
                // variant to `ConnectionError` forces a decision here.
                e @ (VersionMismatch
                | TransportError(_)
                | ConnectionClosed(_)
                | Reset
                | LocallyClosed
                | CidsExhausted) => Self::Undefined(Arc::new(e)),
            }
        }
    }

    impl From<ConnectionError> for StreamErrorIncoming {
        /// A connection-level failure observed by a stream operation.
        fn from(e: ConnectionError) -> Self {
            Self::ConnectionErrorIncoming {
                connection_error: e.into(),
            }
        }
    }

    impl From<SendDatagramError> for SendDatagramErrorIncoming {
        /// Translate a datagram send error into `h3-datagram`'s representation.
        ///
        /// "Unsupported by the peer" and "disabled locally" are both reported as
        /// datagrams being unavailable.
        fn from(e: SendDatagramError) -> Self {
            use SendDatagramError::*;
            match e {
                UnsupportedByPeer | Disabled => Self::NotAvailable,
                TooLarge => Self::TooLarge,
                ConnectionLost(e) => Self::ConnectionError(e.into()),
            }
        }
    }

    impl<B> SendDatagram<B> for Connection
    where
        B: Buf,
    {
        /// Encode `data` and send it as a single datagram.
        ///
        /// The encoded datagram is first collected into one [`Bytes`] buffer.
        fn send_datagram<T: Into<EncodedDatagram<B>>>(
            &mut self,
            data: T,
        ) -> Result<(), SendDatagramErrorIncoming> {
            let mut buf: EncodedDatagram<B> = data.into();
            let buf = buf.copy_to_bytes(buf.remaining());
            Ok(Connection::send_datagram(self, buf)?)
        }
    }

    impl RecvDatagram for Connection {
        type Buffer = Bytes;

        /// Poll for the next received datagram, mapping connection errors into
        /// `h3`'s error type.
        fn poll_incoming_datagram(
            &mut self,
            cx: &mut core::task::Context<'_>,
        ) -> Poll<Result<Self::Buffer, ConnectionErrorIncoming>> {
            Poll::Ready(Ok(ready!(self.poll_recv_datagram(cx))?))
        }
    }

    /// Both datagram handlers are clones of this connection handle.
    impl<B: Buf> DatagramConnectionExt<B> for Connection {
        type RecvDatagramHandler = Self;
        type SendDatagramHandler = Self;

        fn send_datagram_handler(&self) -> Self::SendDatagramHandler {
            self.clone()
        }

        fn recv_datagram_handler(&self) -> Self::RecvDatagramHandler {
            self.clone()
        }
    }

    /// Bidirectional stream: a send half and a receive half that share one
    /// stream id.
    pub struct BidiStream<B> {
        send: SendStream<B>,
        recv: RecvStream,
    }

    impl<B> BidiStream<B> {
        /// Build both halves of stream `stream` on connection `conn`.
        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                send: SendStream::new(conn.clone(), stream, is_0rtt),
                recv: RecvStream::new(conn, stream, is_0rtt),
            }
        }
    }

    impl<B> quic::BidiStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        type RecvStream = RecvStream;
        type SendStream = SendStream<B>;

        fn split(self) -> (Self::SendStream, Self::RecvStream) {
            (self.send, self.recv)
        }
    }

    /// Receive-side operations forward to the `recv` half.
    impl<B> quic::RecvStream for BidiStream<B>
    where
        B: Buf,
    {
        type Buf = Bytes;

        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            self.recv.poll_data(cx)
        }

        fn stop_sending(&mut self, error_code: u64) {
            self.recv.stop_sending(error_code)
        }

        fn recv_id(&self) -> quic::StreamId {
            self.recv.recv_id()
        }
    }

    /// Send-side operations forward to the `send` half.
    impl<B> quic::SendStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_ready(cx)
        }

        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            self.send.send_data(data)
        }

        fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_finish(cx)
        }

        fn reset(&mut self, reset_code: u64) {
            self.send.reset(reset_code)
        }

        fn send_id(&self) -> quic::StreamId {
            self.send.send_id()
        }
    }

    impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            self.send.poll_send(cx, buf)
        }
    }

    /// Stream opener returned by `quic::Connection::opener`.
    ///
    /// It wraps a clone of the [`Connection`] and forwards every call to the
    /// connection's own `quic::OpenStreams` implementation, so the two cannot
    /// drift apart.
    #[derive(Clone)]
    pub struct OpenStreams(Connection);

    impl<B> quic::OpenStreams<B> for OpenStreams
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            <Connection as quic::OpenStreams<B>>::poll_open_bidi(&mut self.0, cx)
        }

        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            <Connection as quic::OpenStreams<B>>::poll_open_send(&mut self.0, cx)
        }

        fn close(&mut self, code: Code, reason: &[u8]) {
            <Connection as quic::OpenStreams<B>>::close(&mut self.0, code, reason)
        }
    }

    impl<B> quic::OpenStreams<B> for Connection
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        /// Open a bidirectional stream, waiting until the connection allows it.
        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Open a unidirectional stream, waiting until the connection allows it.
        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?;
            Poll::Ready(Ok(SendStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Close the connection with an `h3` error code.
        ///
        /// # Panics
        ///
        /// Panics if the code does not fit in a QUIC `VarInt`.
        fn close(&mut self, code: Code, reason: &[u8]) {
            Connection::close(self, code.value().try_into().expect("invalid code"), reason)
        }
    }

    impl<B> quic::Connection<B> for Connection
    where
        B: Buf,
    {
        type OpenStreams = OpenStreams;
        type RecvStream = RecvStream;

        /// Accept the next unidirectional stream opened by the peer.
        fn poll_accept_recv(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Uni))?;
            Poll::Ready(Ok(RecvStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Accept the next bidirectional stream opened by the peer.
        fn poll_accept_bidi(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
        }

        fn opener(&self) -> Self::OpenStreams {
            OpenStreams(self.clone())
        }
    }
}