compio_quic/
endpoint.rs

1use std::{
2    collections::VecDeque,
3    fmt::Debug,
4    io,
5    mem::ManuallyDrop,
6    net::{SocketAddr, SocketAddrV6},
7    ops::Deref,
8    pin::pin,
9    sync::Arc,
10    task::{Context, Poll, Waker},
11    time::Instant,
12};
13
14use compio_buf::{BufResult, bytes::Bytes};
15use compio_log::{Instrument, error};
16#[cfg(rustls)]
17use compio_net::ToSocketAddrsAsync;
18use compio_net::UdpSocket;
19use compio_runtime::JoinHandle;
20use flume::{Receiver, Sender, unbounded};
21use futures_util::{
22    FutureExt, StreamExt,
23    future::{self},
24    select,
25    task::AtomicWaker,
26};
27use quinn_proto::{
28    ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
29    EndpointEvent, ServerConfig, Transmit, VarInt,
30};
31use rustc_hash::FxHashMap as HashMap;
32
33use crate::{
34    Connecting, ConnectionEvent, Incoming, RecvMeta, Socket,
35    sync::{mutex_blocking::Mutex, shared::Shared},
36};
37
38#[derive(Debug)]
39struct EndpointState {
40    endpoint: quinn_proto::Endpoint,
41    worker: Option<JoinHandle<()>>,
42    connections: HashMap<ConnectionHandle, Sender<ConnectionEvent>>,
43    close: Option<(VarInt, Bytes)>,
44    exit_on_idle: bool,
45    incoming: VecDeque<quinn_proto::Incoming>,
46    incoming_wakers: VecDeque<Waker>,
47}
48
49impl EndpointState {
50    fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec<u8>, Transmit)) {
51        let now = Instant::now();
52        for data in buf[..meta.len]
53            .chunks(meta.stride.min(meta.len))
54            .map(Into::into)
55        {
56            let mut resp_buf = Vec::new();
57            match self.endpoint.handle(
58                now,
59                meta.remote,
60                meta.local_ip,
61                meta.ecn,
62                data,
63                &mut resp_buf,
64            ) {
65                Some(DatagramEvent::NewConnection(incoming)) => {
66                    if self.close.is_none() {
67                        self.incoming.push_back(incoming);
68                    } else {
69                        let transmit = self.endpoint.refuse(incoming, &mut resp_buf);
70                        respond_fn(resp_buf, transmit);
71                    }
72                }
73                Some(DatagramEvent::ConnectionEvent(ch, event)) => {
74                    let _ = self
75                        .connections
76                        .get(&ch)
77                        .unwrap()
78                        .send(ConnectionEvent::Proto(event));
79                }
80                Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit),
81                None => {}
82            }
83        }
84    }
85
86    fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) {
87        if event.is_drained() {
88            self.connections.remove(&ch);
89        }
90        if let Some(event) = self.endpoint.handle_event(ch, event) {
91            let _ = self
92                .connections
93                .get(&ch)
94                .unwrap()
95                .send(ConnectionEvent::Proto(event));
96        }
97    }
98
99    fn is_idle(&self) -> bool {
100        self.connections.is_empty()
101    }
102
103    fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<quinn_proto::Incoming>> {
104        if self.close.is_none() {
105            if let Some(incoming) = self.incoming.pop_front() {
106                Poll::Ready(Some(incoming))
107            } else {
108                self.incoming_wakers.push_back(cx.waker().clone());
109                Poll::Pending
110            }
111        } else {
112            Poll::Ready(None)
113        }
114    }
115
116    fn new_connection(
117        &mut self,
118        handle: ConnectionHandle,
119        conn: quinn_proto::Connection,
120        socket: Socket,
121        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
122    ) -> Connecting {
123        let (tx, rx) = unbounded();
124        if let Some((error_code, reason)) = &self.close {
125            tx.send(ConnectionEvent::Close(*error_code, reason.clone()))
126                .unwrap();
127        }
128        self.connections.insert(handle, tx);
129        Connecting::new(handle, conn, socket, events_tx, rx)
130    }
131}
132
133type ChannelPair<T> = (Sender<T>, Receiver<T>);
134
135#[derive(Debug)]
136pub(crate) struct EndpointInner {
137    state: Mutex<EndpointState>,
138    socket: Socket,
139    ipv6: bool,
140    events: ChannelPair<(ConnectionHandle, EndpointEvent)>,
141    done: AtomicWaker,
142}
143
144impl EndpointInner {
145    fn new(
146        socket: UdpSocket,
147        config: EndpointConfig,
148        server_config: Option<ServerConfig>,
149    ) -> io::Result<Self> {
150        let socket = Socket::new(socket)?;
151        let ipv6 = socket.local_addr()?.is_ipv6();
152        let allow_mtud = !socket.may_fragment();
153
154        Ok(Self {
155            state: Mutex::new(EndpointState {
156                endpoint: quinn_proto::Endpoint::new(
157                    Arc::new(config),
158                    server_config.map(Arc::new),
159                    allow_mtud,
160                    None,
161                ),
162                worker: None,
163                connections: HashMap::default(),
164                close: None,
165                exit_on_idle: false,
166                incoming: VecDeque::new(),
167                incoming_wakers: VecDeque::new(),
168            }),
169            socket,
170            ipv6,
171            events: unbounded(),
172            done: AtomicWaker::new(),
173        })
174    }
175
176    fn connect(
177        &self,
178        remote: SocketAddr,
179        server_name: &str,
180        config: ClientConfig,
181    ) -> Result<Connecting, ConnectError> {
182        let mut state = self.state.lock();
183
184        if state.worker.is_none() {
185            return Err(ConnectError::EndpointStopping);
186        }
187        if remote.is_ipv6() && !self.ipv6 {
188            return Err(ConnectError::InvalidRemoteAddress(remote));
189        }
190        let remote = if self.ipv6 {
191            SocketAddr::V6(match remote {
192                SocketAddr::V4(addr) => {
193                    SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
194                }
195                SocketAddr::V6(addr) => addr,
196            })
197        } else {
198            remote
199        };
200
201        let (handle, conn) = state
202            .endpoint
203            .connect(Instant::now(), config, remote, server_name)?;
204
205        Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
206    }
207
208    fn respond(&self, buf: Vec<u8>, transmit: Transmit) {
209        let socket = self.socket.clone();
210        compio_runtime::spawn(async move {
211            socket.send(buf, &transmit).await;
212        })
213        .detach();
214    }
215
216    pub(crate) fn accept(
217        &self,
218        incoming: quinn_proto::Incoming,
219        server_config: Option<ServerConfig>,
220    ) -> Result<Connecting, ConnectionError> {
221        let mut state = self.state.lock();
222        let mut resp_buf = Vec::new();
223        let now = Instant::now();
224        match state
225            .endpoint
226            .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new))
227        {
228            Ok((handle, conn)) => {
229                Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
230            }
231            Err(err) => {
232                if let Some(transmit) = err.response {
233                    self.respond(resp_buf, transmit);
234                }
235                Err(err.cause)
236            }
237        }
238    }
239
240    pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) {
241        let mut state = self.state.lock();
242        let mut resp_buf = Vec::new();
243        let transmit = state.endpoint.refuse(incoming, &mut resp_buf);
244        self.respond(resp_buf, transmit);
245    }
246
247    #[allow(clippy::result_large_err)]
248    pub(crate) fn retry(
249        &self,
250        incoming: quinn_proto::Incoming,
251    ) -> Result<(), quinn_proto::RetryError> {
252        let mut state = self.state.lock();
253        let mut resp_buf = Vec::new();
254        let transmit = state.endpoint.retry(incoming, &mut resp_buf)?;
255        self.respond(resp_buf, transmit);
256        Ok(())
257    }
258
259    pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) {
260        let mut state = self.state.lock();
261        state.endpoint.ignore(incoming);
262    }
263
264    async fn run(&self) -> io::Result<()> {
265        let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);
266
267        let mut recv_fut = pin!(
268            self.socket
269                .recv(Vec::with_capacity(
270                    self.state
271                        .lock()
272                        .endpoint
273                        .config()
274                        .get_max_udp_payload_size()
275                        .min(64 * 1024) as usize
276                        * self.socket.max_gro_segments(),
277                ))
278                .fuse()
279        );
280
281        let mut event_stream = self.events.1.stream().ready_chunks(100);
282
283        loop {
284            let mut state = select! {
285                BufResult(res, recv_buf) = recv_fut => {
286                    let mut state = self.state.lock();
287                    match res {
288                        Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
289                        Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
290                        #[cfg(windows)]
291                        Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
292                        Err(e) => break Err(e),
293                    }
294                    recv_fut.set(self.socket.recv(recv_buf).fuse());
295                    state
296                },
297                events = event_stream.select_next_some() => {
298                    let mut state = self.state.lock();
299                    for (ch, event) in events {
300                        state.handle_event(ch, event);
301                    }
302                    state
303                },
304            };
305
306            if state.exit_on_idle && state.is_idle() {
307                break Ok(());
308            }
309            if !state.incoming.is_empty() {
310                let n = state.incoming.len().min(state.incoming_wakers.len());
311                state.incoming_wakers.drain(..n).for_each(Waker::wake);
312            }
313        }
314    }
315}
316
317#[derive(Debug, Clone)]
318pub(crate) struct EndpointRef(Shared<EndpointInner>);
319
320impl EndpointRef {
321    // Modified from [`SharedFd::try_unwrap_inner`], see notes there.
322    unsafe fn try_unwrap_inner(&self) -> Option<EndpointInner> {
323        let ptr = unsafe { std::ptr::read(&self.0) };
324        match Shared::try_unwrap(ptr) {
325            Ok(inner) => Some(inner),
326            Err(ptr) => {
327                std::mem::forget(ptr);
328                None
329            }
330        }
331    }
332
333    async fn shutdown(self) -> io::Result<()> {
334        let (worker, idle) = {
335            let mut state = self.0.state.lock();
336            let idle = state.is_idle();
337            if !idle {
338                state.exit_on_idle = true;
339            }
340            (state.worker.take(), idle)
341        };
342        if let Some(worker) = worker {
343            if idle {
344                worker.cancel().await;
345            } else {
346                let _ = worker.await;
347            }
348        }
349
350        let this = ManuallyDrop::new(self);
351        let inner = future::poll_fn(move |cx| {
352            if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
353                return Poll::Ready(inner);
354            }
355
356            this.done.register(cx.waker());
357
358            if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
359                Poll::Ready(inner)
360            } else {
361                Poll::Pending
362            }
363        })
364        .await;
365
366        inner.socket.close().await
367    }
368}
369
370impl Drop for EndpointRef {
371    fn drop(&mut self) {
372        if Shared::strong_count(&self.0) == 2 {
373            // There are actually two cases:
374            // 1. User is trying to shutdown the socket.
375            self.0.done.wake();
376            // 2. User dropped the endpoint but the worker is still running.
377            self.0.state.lock().exit_on_idle = true;
378        }
379    }
380}
381
382impl Deref for EndpointRef {
383    type Target = EndpointInner;
384
385    fn deref(&self) -> &Self::Target {
386        &self.0
387    }
388}
389
390/// A QUIC endpoint.
391#[derive(Debug, Clone)]
392pub struct Endpoint {
393    inner: EndpointRef,
394    /// The client configuration used by `connect`
395    pub default_client_config: Option<ClientConfig>,
396}
397
398impl Endpoint {
399    /// Create a QUIC endpoint.
400    pub fn new(
401        socket: UdpSocket,
402        config: EndpointConfig,
403        server_config: Option<ServerConfig>,
404        default_client_config: Option<ClientConfig>,
405    ) -> io::Result<Self> {
406        let inner = EndpointRef(Shared::new(EndpointInner::new(
407            socket,
408            config,
409            server_config,
410        )?));
411        let worker = compio_runtime::spawn({
412            let inner = inner.clone();
413            async move {
414                #[allow(unused)]
415                if let Err(e) = inner.run().await {
416                    error!("I/O error: {}", e);
417                }
418            }
419            .in_current_span()
420        });
421        inner.state.lock().worker = Some(worker);
422        Ok(Self {
423            inner,
424            default_client_config,
425        })
426    }
427
428    /// Helper to construct an endpoint for use with outgoing connections only.
429    ///
430    /// Note that `addr` is the *local* address to bind to, which should usually
431    /// be a wildcard address like `0.0.0.0:0` or `[::]:0`, which allow
432    /// communication with any reachable IPv4 or IPv6 address respectively
433    /// from an OS-assigned port.
434    ///
435    /// If an IPv6 address is provided, the socket may dual-stack depending on
436    /// the platform, so as to allow communication with both IPv4 and IPv6
437    /// addresses. As such, calling this method with the address `[::]:0` is a
438    /// reasonable default to maximize the ability to connect to other
439    /// address.
440    ///
441    /// IPv4 client is never dual-stack.
442    #[cfg(rustls)]
443    pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
444        // TODO: try to enable dual-stack on all platforms, notably Windows
445        let socket = UdpSocket::bind(addr).await?;
446        Self::new(socket, EndpointConfig::default(), None, None)
447    }
448
449    /// Helper to construct an endpoint for use with both incoming and outgoing
450    /// connections
451    ///
452    /// Platform defaults for dual-stack sockets vary. For example, any socket
453    /// bound to a wildcard IPv6 address on Windows will not by default be
454    /// able to communicate with IPv4 addresses. Portable applications
455    /// should bind an address that matches the family they wish to
456    /// communicate within.
457    #[cfg(rustls)]
458    pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result<Self> {
459        let socket = UdpSocket::bind(addr).await?;
460        Self::new(socket, EndpointConfig::default(), Some(config), None)
461    }
462
463    /// Connect to a remote endpoint.
464    pub fn connect(
465        &self,
466        remote: SocketAddr,
467        server_name: &str,
468        config: Option<ClientConfig>,
469    ) -> Result<Connecting, ConnectError> {
470        let config = config
471            .or_else(|| self.default_client_config.clone())
472            .ok_or(ConnectError::NoDefaultClientConfig)?;
473
474        self.inner.connect(remote, server_name, config)
475    }
476
477    /// Wait for the next incoming connection attempt from a client.
478    ///
479    /// Yields [`Incoming`]s, or `None` if the endpoint is
480    /// [`close`](Self::close)d. [`Incoming`] can be `await`ed to obtain the
481    /// final [`Connection`](crate::Connection), or used to e.g. filter
482    /// connection attempts or force address validation, or converted into an
483    /// intermediate `Connecting` future which can be used to e.g. send 0.5-RTT
484    /// data.
485    pub async fn wait_incoming(&self) -> Option<Incoming> {
486        future::poll_fn(|cx| self.inner.state.lock().poll_incoming(cx))
487            .await
488            .map(|incoming| Incoming::new(incoming, self.inner.clone()))
489    }
490
491    /// Replace the server configuration, affecting new incoming connections
492    /// only.
493    ///
494    /// Useful for e.g. refreshing TLS certificates without disrupting existing
495    /// connections.
496    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
497        self.inner
498            .state
499            .lock()
500            .endpoint
501            .set_server_config(server_config.map(Arc::new))
502    }
503
504    /// Get the local `SocketAddr` the underlying socket is bound to.
505    pub fn local_addr(&self) -> io::Result<SocketAddr> {
506        self.inner.socket.local_addr()
507    }
508
509    /// Get the number of connections that are currently open.
510    pub fn open_connections(&self) -> usize {
511        self.inner.state.lock().endpoint.open_connections()
512    }
513
514    /// Close all of this endpoint's connections immediately and cease accepting
515    /// new connections.
516    ///
517    /// See [`Connection::close()`] for details.
518    ///
519    /// [`Connection::close()`]: crate::Connection::close
520    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
521        let reason = Bytes::copy_from_slice(reason);
522        let mut state = self.inner.state.lock();
523        if state.close.is_some() {
524            return;
525        }
526        state.close = Some((error_code, reason.clone()));
527        for conn in state.connections.values() {
528            let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone()));
529        }
530        state.incoming_wakers.drain(..).for_each(Waker::wake);
531    }
532
533    /// Gracefully shutdown the endpoint.
534    ///
535    /// Wait for all connections on the endpoint to be cleanly shut down and
536    /// close the underlying socket. This will wait for all clones of the
537    /// endpoint, all connections and all streams to be dropped before
538    /// closing the socket.
539    ///
540    /// Waiting for this condition before exiting ensures that a good-faith
541    /// effort is made to notify peers of recent connection closes, whereas
542    /// exiting immediately could force them to wait out the idle timeout
543    /// period.
544    ///
545    /// Does not proactively close existing connections. Consider calling
546    /// [`close()`] if that is desired.
547    ///
548    /// [`close()`]: Endpoint::close
549    pub async fn shutdown(self) -> io::Result<()> {
550        self.inner.shutdown().await
551    }
552}