Skip to main content

compio_quic/
endpoint.rs

1use std::{
2    collections::VecDeque,
3    fmt::Debug,
4    future::poll_fn,
5    io,
6    mem::ManuallyDrop,
7    net::{SocketAddr, SocketAddrV6},
8    ops::Deref,
9    pin::pin,
10    ptr,
11    sync::Arc,
12    task::{Context, Poll, Waker},
13    time::Instant,
14};
15
16use compio_buf::{BufResult, bytes::Bytes};
17use compio_log::{Instrument, error};
18#[cfg(rustls)]
19use compio_net::ToSocketAddrsAsync;
20use compio_net::UdpSocket;
21use compio_runtime::JoinHandle;
22use flume::{Receiver, Sender, unbounded};
23use futures_util::{FutureExt, StreamExt, future, select, task::AtomicWaker};
24use quinn_proto::{
25    ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
26    EndpointEvent, ServerConfig, Transmit, VarInt,
27};
28use rustc_hash::FxHashMap as HashMap;
29
30use crate::{
31    Connecting, ConnectionEvent, Incoming, RecvMeta, Socket,
32    sync::{mutex_blocking::Mutex, shared::Shared},
33};
34
/// Mutable endpoint state, kept behind the lock in `EndpointInner::state`.
#[derive(Debug)]
struct EndpointState {
    /// The `quinn_proto` protocol state machine.
    endpoint: quinn_proto::Endpoint,
    /// Handle of the task running `EndpointInner::run`; taken by
    /// `EndpointRef::shutdown`.
    worker: Option<JoinHandle<()>>,
    /// Per-connection senders used to forward [`ConnectionEvent`]s.
    connections: HashMap<ConnectionHandle, Sender<ConnectionEvent>>,
    /// Error code and reason stored by `Endpoint::close`; `None` until then.
    close: Option<(VarInt, Bytes)>,
    /// When set, the worker loop exits as soon as `connections` is empty.
    exit_on_idle: bool,
    /// Connection attempts received but not yet returned by `poll_incoming`.
    incoming: VecDeque<quinn_proto::Incoming>,
    /// Wakers registered by `poll_incoming` callers that returned `Pending`.
    incoming_wakers: VecDeque<Waker>,
    /// Handshake counters returned by `Endpoint::stats`.
    stats: EndpointStats,
}
46
/// Counters describing the handshake activity of an [`Endpoint`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default)]
pub struct EndpointStats {
    /// Total number of QUIC handshakes this [`Endpoint`] has accepted.
    pub accepted_handshakes: u64,
    /// Total number of QUIC handshakes this [`Endpoint`] has initiated.
    pub outgoing_handshakes: u64,
    /// Total number of QUIC handshakes this [`Endpoint`] has refused.
    pub refused_handshakes: u64,
    /// Total number of QUIC handshakes this [`Endpoint`] has ignored.
    pub ignored_handshakes: u64,
}
60
61impl EndpointState {
62    fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec<u8>, Transmit)) {
63        let now = Instant::now();
64        for data in buf[..meta.len]
65            .chunks(meta.stride.min(meta.len))
66            .map(Into::into)
67        {
68            let mut resp_buf = Vec::new();
69            match self.endpoint.handle(
70                now,
71                meta.remote,
72                meta.local_ip,
73                meta.ecn,
74                data,
75                &mut resp_buf,
76            ) {
77                Some(DatagramEvent::NewConnection(incoming)) => {
78                    if self.close.is_none() {
79                        self.incoming.push_back(incoming);
80                    } else {
81                        let transmit = self.endpoint.refuse(incoming, &mut resp_buf);
82                        respond_fn(resp_buf, transmit);
83                    }
84                }
85                Some(DatagramEvent::ConnectionEvent(ch, event)) => {
86                    let _ = self
87                        .connections
88                        .get(&ch)
89                        .unwrap()
90                        .send(ConnectionEvent::Proto(event));
91                }
92                Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit),
93                None => {}
94            }
95        }
96    }
97
98    fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) {
99        if event.is_drained() {
100            self.connections.remove(&ch);
101        }
102        if let Some(event) = self.endpoint.handle_event(ch, event) {
103            let _ = self
104                .connections
105                .get(&ch)
106                .unwrap()
107                .send(ConnectionEvent::Proto(event));
108        }
109    }
110
111    fn is_idle(&self) -> bool {
112        self.connections.is_empty()
113    }
114
115    fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<quinn_proto::Incoming>> {
116        if self.close.is_none() {
117            if let Some(incoming) = self.incoming.pop_front() {
118                Poll::Ready(Some(incoming))
119            } else {
120                self.incoming_wakers.push_back(cx.waker().clone());
121                Poll::Pending
122            }
123        } else {
124            Poll::Ready(None)
125        }
126    }
127
128    fn new_connection(
129        &mut self,
130        handle: ConnectionHandle,
131        conn: quinn_proto::Connection,
132        socket: Socket,
133        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
134    ) -> Connecting {
135        let (tx, rx) = unbounded();
136        if let Some((error_code, reason)) = &self.close {
137            tx.send(ConnectionEvent::Close(*error_code, reason.clone()))
138                .unwrap();
139        }
140        self.connections.insert(handle, tx);
141        Connecting::new(handle, conn, socket, events_tx, rx)
142    }
143}
144
145impl Drop for EndpointState {
146    fn drop(&mut self) {
147        for incoming in self.incoming.drain(..) {
148            self.endpoint.ignore(incoming);
149        }
150    }
151}
152
/// Sender and receiver halves of one `flume` channel carrying `T`, in the
/// order returned by `unbounded`.
type ChannelPair<T> = (Sender<T>, Receiver<T>);
154
/// Data shared between the endpoint handles and the worker task.
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Protocol and bookkeeping state, guarded by a lock.
    state: Mutex<EndpointState>,
    /// Socket wrapper used for every send and receive.
    socket: Socket,
    /// Whether the socket's local address is an IPv6 address.
    ipv6: bool,
    /// Channel on which connections report [`EndpointEvent`]s: the sender is
    /// cloned into every new connection, the receiver is consumed by `run`.
    events: ChannelPair<(ConnectionHandle, EndpointEvent)>,
    /// Woken from `EndpointRef::drop`; `EndpointRef::shutdown` registers on it
    /// while waiting to become the sole owner.
    done: AtomicWaker,
}
163
164impl EndpointInner {
165    fn new(
166        socket: UdpSocket,
167        config: EndpointConfig,
168        server_config: Option<ServerConfig>,
169    ) -> io::Result<Self> {
170        let socket = Socket::new(socket)?;
171        let ipv6 = socket.local_addr()?.is_ipv6();
172        let allow_mtud = !socket.may_fragment();
173
174        Ok(Self {
175            state: Mutex::new(EndpointState {
176                endpoint: quinn_proto::Endpoint::new(
177                    Arc::new(config),
178                    server_config.map(Arc::new),
179                    allow_mtud,
180                    None,
181                ),
182                worker: None,
183                connections: HashMap::default(),
184                close: None,
185                exit_on_idle: false,
186                incoming: VecDeque::new(),
187                incoming_wakers: VecDeque::new(),
188                stats: EndpointStats::default(),
189            }),
190            socket,
191            ipv6,
192            events: unbounded(),
193            done: AtomicWaker::new(),
194        })
195    }
196
197    fn connect(
198        &self,
199        remote: SocketAddr,
200        server_name: &str,
201        config: ClientConfig,
202    ) -> Result<Connecting, ConnectError> {
203        let mut state = self.state.lock();
204
205        if state.worker.is_none() {
206            return Err(ConnectError::EndpointStopping);
207        }
208        if remote.is_ipv6() && !self.ipv6 {
209            return Err(ConnectError::InvalidRemoteAddress(remote));
210        }
211        let remote = if self.ipv6 {
212            SocketAddr::V6(match remote {
213                SocketAddr::V4(addr) => {
214                    SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
215                }
216                SocketAddr::V6(addr) => addr,
217            })
218        } else {
219            remote
220        };
221
222        let (handle, conn) = state
223            .endpoint
224            .connect(Instant::now(), config, remote, server_name)?;
225        state.stats.outgoing_handshakes += 1;
226
227        Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
228    }
229
230    fn respond(&self, buf: Vec<u8>, transmit: Transmit) {
231        let socket = self.socket.clone();
232        compio_runtime::spawn(async move {
233            socket.send(buf, &transmit).await;
234        })
235        .detach();
236    }
237
238    pub(crate) fn accept(
239        &self,
240        incoming: quinn_proto::Incoming,
241        server_config: Option<ServerConfig>,
242    ) -> Result<Connecting, ConnectionError> {
243        let mut state = self.state.lock();
244        let mut resp_buf = Vec::new();
245        let now = Instant::now();
246        match state
247            .endpoint
248            .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new))
249        {
250            Ok((handle, conn)) => {
251                state.stats.accepted_handshakes += 1;
252                Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
253            }
254            Err(err) => {
255                if let Some(transmit) = err.response {
256                    self.respond(resp_buf, transmit);
257                }
258                Err(err.cause)
259            }
260        }
261    }
262
263    pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) {
264        let mut state = self.state.lock();
265        state.stats.refused_handshakes += 1;
266        let mut resp_buf = Vec::new();
267        let transmit = state.endpoint.refuse(incoming, &mut resp_buf);
268        self.respond(resp_buf, transmit);
269    }
270
271    #[allow(clippy::result_large_err)]
272    pub(crate) fn retry(
273        &self,
274        incoming: quinn_proto::Incoming,
275    ) -> Result<(), quinn_proto::RetryError> {
276        let mut state = self.state.lock();
277        let mut resp_buf = Vec::new();
278        let transmit = state.endpoint.retry(incoming, &mut resp_buf)?;
279        self.respond(resp_buf, transmit);
280        Ok(())
281    }
282
283    pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) {
284        let mut state = self.state.lock();
285        state.stats.ignored_handshakes += 1;
286        state.endpoint.ignore(incoming);
287    }
288
289    async fn run(&self) -> io::Result<()> {
290        let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);
291
292        let mut recv_fut = pin!(
293            self.socket
294                .recv(Vec::with_capacity(
295                    self.state
296                        .lock()
297                        .endpoint
298                        .config()
299                        .get_max_udp_payload_size()
300                        .min(64 * 1024) as usize
301                        * self.socket.max_gro_segments(),
302                ))
303                .fuse()
304        );
305
306        let mut event_stream = self.events.1.stream().ready_chunks(100);
307
308        loop {
309            let mut state = select! {
310                BufResult(res, recv_buf) = recv_fut => {
311                    let mut state = self.state.lock();
312                    match res {
313                        Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
314                        Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
315                        #[cfg(windows)]
316                        Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
317                        Err(e) => break Err(e),
318                    }
319                    recv_fut.set(self.socket.recv(recv_buf).fuse());
320                    state
321                },
322                events = event_stream.select_next_some() => {
323                    let mut state = self.state.lock();
324                    for (ch, event) in events {
325                        state.handle_event(ch, event);
326                    }
327                    state
328                },
329            };
330
331            if state.exit_on_idle && state.is_idle() {
332                break Ok(());
333            }
334            if !state.incoming.is_empty() {
335                let n = state.incoming.len().min(state.incoming_wakers.len());
336                state.incoming_wakers.drain(..n).for_each(Waker::wake);
337            }
338        }
339    }
340}
341
/// Handle to the shared [`EndpointInner`].
///
/// Its `Drop` impl inspects the remaining strong count to notify a pending
/// shutdown or to ask the worker to exit once idle.
#[derive(Debug, Clone)]
pub(crate) struct EndpointRef(Shared<EndpointInner>);
344
impl EndpointRef {
    /// Consume the handle and return the wrapped `Shared` without running
    /// the `Drop` impl of `EndpointRef`.
    fn into_inner(self) -> Shared<EndpointInner> {
        // `ManuallyDrop` suppresses `EndpointRef::drop` for `this`.
        let this = ManuallyDrop::new(self);
        // SAFETY: `this` is not dropped here, and we're consuming Self
        unsafe { ptr::read(&this.0) }
    }

    /// Stop the worker, wait for sole ownership of the endpoint, then close
    /// the socket.
    ///
    /// If no connection is registered the worker is cancelled right away;
    /// otherwise `exit_on_idle` is set and the worker is awaited. The future
    /// then stays pending until `Shared::try_unwrap` succeeds.
    ///
    /// # Errors
    ///
    /// Returns the result of closing the socket.
    async fn shutdown(self) -> io::Result<()> {
        // Take the worker handle and read the idle flag under a single lock
        // acquisition; the guard is released before anything is awaited.
        let (worker, idle) = {
            let mut state = self.0.state.lock();
            let idle = state.is_idle();
            if !idle {
                state.exit_on_idle = true;
            }
            (state.worker.take(), idle)
        };
        if let Some(worker) = worker {
            if idle {
                worker.cancel().await;
            } else {
                // The worker's outcome is not needed here.
                let _ = worker.await;
            }
        }

        // `try_unwrap` takes the `Shared` by value, so it lives in an `Option`
        // and is put back whenever unwrapping fails.
        let mut this = Some(self.into_inner());
        let inner = poll_fn(move |cx| {
            let s = match Shared::try_unwrap(this.take().unwrap()) {
                Ok(inner) => return Poll::Ready(inner),
                Err(s) => s,
            };

            s.done.register(cx.waker());

            // Second attempt after registering.
            // NOTE(review): presumably covers a reference being released
            // between the first attempt and the registration — confirm.
            match Shared::try_unwrap(s) {
                Ok(inner) => Poll::Ready(inner),
                Err(s) => {
                    this.replace(s);
                    Poll::Pending
                }
            }
        })
        .await;

        inner.socket.close().await
    }
}
391
392impl Drop for EndpointRef {
393    fn drop(&mut self) {
394        if Shared::strong_count(&self.0) == 2 {
395            // There are actually two cases:
396            // 1. User is trying to shutdown the socket.
397            self.0.done.wake();
398            // 2. User dropped the endpoint but the worker is still running.
399            self.0.state.lock().exit_on_idle = true;
400        }
401    }
402}
403
404impl Deref for EndpointRef {
405    type Target = EndpointInner;
406
407    fn deref(&self) -> &Self::Target {
408        &self.0
409    }
410}
411
/// A QUIC endpoint.
///
/// Owns a UDP socket together with the worker task driving it. Clones are
/// handles to the same underlying endpoint, each carrying its own copy of
/// `default_client_config`.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Handle to the state shared with the worker task.
    inner: EndpointRef,
    /// The client configuration used by `connect` when the caller does not
    /// pass one explicitly.
    pub default_client_config: Option<ClientConfig>,
}
419
420impl Endpoint {
421    /// Create a QUIC endpoint.
422    pub fn new(
423        socket: UdpSocket,
424        config: EndpointConfig,
425        server_config: Option<ServerConfig>,
426        default_client_config: Option<ClientConfig>,
427    ) -> io::Result<Self> {
428        let inner = EndpointRef(Shared::new(EndpointInner::new(
429            socket,
430            config,
431            server_config,
432        )?));
433        let worker = compio_runtime::spawn({
434            let inner = inner.clone();
435            async move {
436                #[allow(unused)]
437                if let Err(e) = inner.run().await {
438                    error!("I/O error: {}", e);
439                }
440            }
441            .in_current_span()
442        });
443        inner.state.lock().worker = Some(worker);
444        Ok(Self {
445            inner,
446            default_client_config,
447        })
448    }
449
450    /// Helper to construct an endpoint for use with outgoing connections only.
451    ///
452    /// Note that `addr` is the *local* address to bind to, which should usually
453    /// be a wildcard address like `0.0.0.0:0` or `[::]:0`, which allow
454    /// communication with any reachable IPv4 or IPv6 address respectively
455    /// from an OS-assigned port.
456    ///
457    /// If an IPv6 address is provided, the socket may dual-stack depending on
458    /// the platform, so as to allow communication with both IPv4 and IPv6
459    /// addresses. As such, calling this method with the address `[::]:0` is a
460    /// reasonable default to maximize the ability to connect to other
461    /// address.
462    ///
463    /// IPv4 client is never dual-stack.
464    #[cfg(rustls)]
465    pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
466        // TODO: try to enable dual-stack on all platforms, notably Windows
467        let socket = UdpSocket::bind(addr).await?;
468        Self::new(socket, EndpointConfig::default(), None, None)
469    }
470
471    /// Helper to construct an endpoint for use with both incoming and outgoing
472    /// connections
473    ///
474    /// Platform defaults for dual-stack sockets vary. For example, any socket
475    /// bound to a wildcard IPv6 address on Windows will not by default be
476    /// able to communicate with IPv4 addresses. Portable applications
477    /// should bind an address that matches the family they wish to
478    /// communicate within.
479    #[cfg(rustls)]
480    pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result<Self> {
481        let socket = UdpSocket::bind(addr).await?;
482        Self::new(socket, EndpointConfig::default(), Some(config), None)
483    }
484
485    /// Returns relevant stats from this Endpoint
486    pub fn stats(&self) -> EndpointStats {
487        self.inner.state.lock().stats
488    }
489
490    /// Connect to a remote endpoint.
491    pub fn connect(
492        &self,
493        remote: SocketAddr,
494        server_name: &str,
495        config: Option<ClientConfig>,
496    ) -> Result<Connecting, ConnectError> {
497        let config = config
498            .or_else(|| self.default_client_config.clone())
499            .ok_or(ConnectError::NoDefaultClientConfig)?;
500
501        self.inner.connect(remote, server_name, config)
502    }
503
504    /// Wait for the next incoming connection attempt from a client.
505    ///
506    /// Yields [`Incoming`]s, or `None` if the endpoint is
507    /// [`close`](Self::close)d. [`Incoming`] can be `await`ed to obtain the
508    /// final [`Connection`](crate::Connection), or used to e.g. filter
509    /// connection attempts or force address validation, or converted into an
510    /// intermediate `Connecting` future which can be used to e.g. send 0.5-RTT
511    /// data.
512    pub async fn wait_incoming(&self) -> Option<Incoming> {
513        future::poll_fn(|cx| self.inner.state.lock().poll_incoming(cx))
514            .await
515            .map(|incoming| Incoming::new(incoming, self.inner.clone()))
516    }
517
518    /// Replace the server configuration, affecting new incoming connections
519    /// only.
520    ///
521    /// Useful for e.g. refreshing TLS certificates without disrupting existing
522    /// connections.
523    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
524        self.inner
525            .state
526            .lock()
527            .endpoint
528            .set_server_config(server_config.map(Arc::new))
529    }
530
531    /// Get the local `SocketAddr` the underlying socket is bound to.
532    pub fn local_addr(&self) -> io::Result<SocketAddr> {
533        self.inner.socket.local_addr()
534    }
535
536    /// Get the number of connections that are currently open.
537    pub fn open_connections(&self) -> usize {
538        self.inner.state.lock().endpoint.open_connections()
539    }
540
541    /// Close all of this endpoint's connections immediately and cease accepting
542    /// new connections.
543    ///
544    /// See [`Connection::close()`] for details.
545    ///
546    /// [`Connection::close()`]: crate::Connection::close
547    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
548        let reason = Bytes::copy_from_slice(reason);
549        let mut state = self.inner.state.lock();
550        if state.close.is_some() {
551            return;
552        }
553        state.close = Some((error_code, reason.clone()));
554        for conn in state.connections.values() {
555            let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone()));
556        }
557        state.incoming_wakers.drain(..).for_each(Waker::wake);
558    }
559
560    /// Gracefully shutdown the endpoint.
561    ///
562    /// Wait for all connections on the endpoint to be cleanly shut down and
563    /// close the underlying socket. This will wait for all clones of the
564    /// endpoint, all connections and all streams to be dropped before
565    /// closing the socket.
566    ///
567    /// Waiting for this condition before exiting ensures that a good-faith
568    /// effort is made to notify peers of recent connection closes, whereas
569    /// exiting immediately could force them to wait out the idle timeout
570    /// period.
571    ///
572    /// Does not proactively close existing connections. Consider calling
573    /// [`close()`] if that is desired.
574    ///
575    /// [`close()`]: Endpoint::close
576    pub async fn shutdown(self) -> io::Result<()> {
577        self.inner.shutdown().await
578    }
579}