1use std::{
2 collections::VecDeque,
3 fmt::Debug,
4 future::poll_fn,
5 io,
6 mem::ManuallyDrop,
7 net::{SocketAddr, SocketAddrV6},
8 ops::Deref,
9 pin::pin,
10 ptr,
11 sync::Arc,
12 task::{Context, Poll, Waker},
13 time::Instant,
14};
15
16use compio_buf::{BufResult, bytes::Bytes};
17use compio_log::{Instrument, error};
18#[cfg(rustls)]
19use compio_net::ToSocketAddrsAsync;
20use compio_net::UdpSocket;
21use compio_runtime::JoinHandle;
22use flume::{Receiver, Sender, unbounded};
23use futures_util::{FutureExt, StreamExt, future, select, task::AtomicWaker};
24use quinn_proto::{
25 ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
26 EndpointEvent, ServerConfig, Transmit, VarInt,
27};
28use rustc_hash::FxHashMap as HashMap;
29
30use crate::{
31 Connecting, ConnectionEvent, Incoming, RecvMeta, Socket,
32 sync::{mutex_blocking::Mutex, shared::Shared},
33};
34
/// Mutable endpoint state, stored behind the mutex in `EndpointInner`.
#[derive(Debug)]
struct EndpointState {
    /// Protocol endpoint that received datagrams and endpoint events are fed into.
    endpoint: quinn_proto::Endpoint,
    /// Handle of the spawned task running `EndpointInner::run`; taken by shutdown.
    worker: Option<JoinHandle<()>>,
    /// Per-connection senders used to forward `ConnectionEvent`s.
    connections: HashMap<ConnectionHandle, Sender<ConnectionEvent>>,
    /// Error code and reason recorded by `Endpoint::close`; `None` until then.
    close: Option<(VarInt, Bytes)>,
    /// When set, the worker loop exits as soon as no connections remain.
    exit_on_idle: bool,
    /// Incoming connection attempts not yet handed out by `poll_incoming`.
    incoming: VecDeque<quinn_proto::Incoming>,
    /// Wakers of tasks that polled `poll_incoming` while the queue was empty.
    incoming_wakers: VecDeque<Waker>,
    /// Handshake counters returned by `Endpoint::stats`.
    stats: EndpointStats,
}
46
/// Handshake counters of an endpoint.
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Number of incoming handshakes that were accepted successfully.
    pub accepted_handshakes: u64,
    /// Number of outgoing handshakes started through `connect`.
    pub outgoing_handshakes: u64,
    /// Number of incoming handshakes explicitly refused.
    pub refused_handshakes: u64,
    /// Number of incoming handshakes explicitly ignored.
    pub ignored_handshakes: u64,
}
60
61impl EndpointState {
62 fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec<u8>, Transmit)) {
63 let now = Instant::now();
64 for data in buf[..meta.len]
65 .chunks(meta.stride.min(meta.len))
66 .map(Into::into)
67 {
68 let mut resp_buf = Vec::new();
69 match self.endpoint.handle(
70 now,
71 meta.remote,
72 meta.local_ip,
73 meta.ecn,
74 data,
75 &mut resp_buf,
76 ) {
77 Some(DatagramEvent::NewConnection(incoming)) => {
78 if self.close.is_none() {
79 self.incoming.push_back(incoming);
80 } else {
81 let transmit = self.endpoint.refuse(incoming, &mut resp_buf);
82 respond_fn(resp_buf, transmit);
83 }
84 }
85 Some(DatagramEvent::ConnectionEvent(ch, event)) => {
86 let _ = self
87 .connections
88 .get(&ch)
89 .unwrap()
90 .send(ConnectionEvent::Proto(event));
91 }
92 Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit),
93 None => {}
94 }
95 }
96 }
97
98 fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) {
99 if event.is_drained() {
100 self.connections.remove(&ch);
101 }
102 if let Some(event) = self.endpoint.handle_event(ch, event) {
103 let _ = self
104 .connections
105 .get(&ch)
106 .unwrap()
107 .send(ConnectionEvent::Proto(event));
108 }
109 }
110
111 fn is_idle(&self) -> bool {
112 self.connections.is_empty()
113 }
114
115 fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<quinn_proto::Incoming>> {
116 if self.close.is_none() {
117 if let Some(incoming) = self.incoming.pop_front() {
118 Poll::Ready(Some(incoming))
119 } else {
120 self.incoming_wakers.push_back(cx.waker().clone());
121 Poll::Pending
122 }
123 } else {
124 Poll::Ready(None)
125 }
126 }
127
128 fn new_connection(
129 &mut self,
130 handle: ConnectionHandle,
131 conn: quinn_proto::Connection,
132 socket: Socket,
133 events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
134 ) -> Connecting {
135 let (tx, rx) = unbounded();
136 if let Some((error_code, reason)) = &self.close {
137 tx.send(ConnectionEvent::Close(*error_code, reason.clone()))
138 .unwrap();
139 }
140 self.connections.insert(handle, tx);
141 Connecting::new(handle, conn, socket, events_tx, rx)
142 }
143}
144
145impl Drop for EndpointState {
146 fn drop(&mut self) {
147 for incoming in self.incoming.drain(..) {
148 self.endpoint.ignore(incoming);
149 }
150 }
151}
152
/// Both halves of a flume channel carrying `T`, kept together.
type ChannelPair<T> = (Sender<T>, Receiver<T>);
154
/// Shared core of an endpoint: locked state, the socket and the event channel.
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable state, see `EndpointState`.
    state: Mutex<EndpointState>,
    /// Socket used for receiving datagrams and sending responses.
    socket: Socket,
    /// Whether the socket's local address is IPv6 (determined in `new`).
    ipv6: bool,
    /// Channel of `(handle, event)` pairs; the sender is cloned into each
    /// connection and the receiver is drained by `run`.
    events: ChannelPair<(ConnectionHandle, EndpointEvent)>,
    /// Waker registered by `EndpointRef::shutdown` and woken from
    /// `EndpointRef`'s `Drop`.
    done: AtomicWaker,
}
163
164impl EndpointInner {
165 fn new(
166 socket: UdpSocket,
167 config: EndpointConfig,
168 server_config: Option<ServerConfig>,
169 ) -> io::Result<Self> {
170 let socket = Socket::new(socket)?;
171 let ipv6 = socket.local_addr()?.is_ipv6();
172 let allow_mtud = !socket.may_fragment();
173
174 Ok(Self {
175 state: Mutex::new(EndpointState {
176 endpoint: quinn_proto::Endpoint::new(
177 Arc::new(config),
178 server_config.map(Arc::new),
179 allow_mtud,
180 None,
181 ),
182 worker: None,
183 connections: HashMap::default(),
184 close: None,
185 exit_on_idle: false,
186 incoming: VecDeque::new(),
187 incoming_wakers: VecDeque::new(),
188 stats: EndpointStats::default(),
189 }),
190 socket,
191 ipv6,
192 events: unbounded(),
193 done: AtomicWaker::new(),
194 })
195 }
196
197 fn connect(
198 &self,
199 remote: SocketAddr,
200 server_name: &str,
201 config: ClientConfig,
202 ) -> Result<Connecting, ConnectError> {
203 let mut state = self.state.lock();
204
205 if state.worker.is_none() {
206 return Err(ConnectError::EndpointStopping);
207 }
208 if remote.is_ipv6() && !self.ipv6 {
209 return Err(ConnectError::InvalidRemoteAddress(remote));
210 }
211 let remote = if self.ipv6 {
212 SocketAddr::V6(match remote {
213 SocketAddr::V4(addr) => {
214 SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
215 }
216 SocketAddr::V6(addr) => addr,
217 })
218 } else {
219 remote
220 };
221
222 let (handle, conn) = state
223 .endpoint
224 .connect(Instant::now(), config, remote, server_name)?;
225 state.stats.outgoing_handshakes += 1;
226
227 Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
228 }
229
230 fn respond(&self, buf: Vec<u8>, transmit: Transmit) {
231 let socket = self.socket.clone();
232 compio_runtime::spawn(async move {
233 socket.send(buf, &transmit).await;
234 })
235 .detach();
236 }
237
238 pub(crate) fn accept(
239 &self,
240 incoming: quinn_proto::Incoming,
241 server_config: Option<ServerConfig>,
242 ) -> Result<Connecting, ConnectionError> {
243 let mut state = self.state.lock();
244 let mut resp_buf = Vec::new();
245 let now = Instant::now();
246 match state
247 .endpoint
248 .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new))
249 {
250 Ok((handle, conn)) => {
251 state.stats.accepted_handshakes += 1;
252 Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
253 }
254 Err(err) => {
255 if let Some(transmit) = err.response {
256 self.respond(resp_buf, transmit);
257 }
258 Err(err.cause)
259 }
260 }
261 }
262
263 pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) {
264 let mut state = self.state.lock();
265 state.stats.refused_handshakes += 1;
266 let mut resp_buf = Vec::new();
267 let transmit = state.endpoint.refuse(incoming, &mut resp_buf);
268 self.respond(resp_buf, transmit);
269 }
270
271 #[allow(clippy::result_large_err)]
272 pub(crate) fn retry(
273 &self,
274 incoming: quinn_proto::Incoming,
275 ) -> Result<(), quinn_proto::RetryError> {
276 let mut state = self.state.lock();
277 let mut resp_buf = Vec::new();
278 let transmit = state.endpoint.retry(incoming, &mut resp_buf)?;
279 self.respond(resp_buf, transmit);
280 Ok(())
281 }
282
283 pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) {
284 let mut state = self.state.lock();
285 state.stats.ignored_handshakes += 1;
286 state.endpoint.ignore(incoming);
287 }
288
    /// Worker loop of the endpoint.
    ///
    /// Repeatedly waits for either a received buffer from the socket or a batch
    /// of endpoint events from the connections, and applies it to the locked
    /// state. Returns `Ok(())` once `exit_on_idle` is set and no connections
    /// remain, or the receive error that terminated the loop.
    async fn run(&self) -> io::Result<()> {
        let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);

        // Receive buffer capacity: the configured maximum UDP payload size
        // (capped at 64 KiB) times the socket's maximum number of GRO segments.
        let mut recv_fut = pin!(
            self.socket
                .recv(Vec::with_capacity(
                    self.state
                        .lock()
                        .endpoint
                        .config()
                        .get_max_udp_payload_size()
                        .min(64 * 1024) as usize
                        * self.socket.max_gro_segments(),
                ))
                .fuse()
        );

        // Endpoint events are consumed in batches of up to 100 ready items.
        let mut event_stream = self.events.1.stream().ready_chunks(100);

        loop {
            // Each branch returns the state guard so the checks below run
            // under the same lock acquisition.
            let mut state = select! {
                BufResult(res, recv_buf) = recv_fut => {
                    let mut state = self.state.lock();
                    match res {
                        Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
                        // Connection-reset style errors are skipped, not fatal.
                        Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
                        #[cfg(windows)]
                        Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
                        Err(e) => break Err(e),
                    }
                    // Re-arm the receive, reusing the same buffer.
                    recv_fut.set(self.socket.recv(recv_buf).fuse());
                    state
                },
                events = event_stream.select_next_some() => {
                    let mut state = self.state.lock();
                    for (ch, event) in events {
                        state.handle_event(ch, event);
                    }
                    state
                },
            };

            if state.exit_on_idle && state.is_idle() {
                break Ok(());
            }
            if !state.incoming.is_empty() {
                // Wake at most one waiter per queued incoming attempt.
                let n = state.incoming.len().min(state.incoming_wakers.len());
                state.incoming_wakers.drain(..n).for_each(Waker::wake);
            }
        }
    }
340}
341
/// Shared handle to an `EndpointInner`; its `Drop` impl checks the strong
/// count and may wake `done` and set `exit_on_idle`.
#[derive(Debug, Clone)]
pub(crate) struct EndpointRef(Shared<EndpointInner>);
344
345impl EndpointRef {
    /// Extracts the inner `Shared` without running `EndpointRef`'s `Drop`.
    fn into_inner(self) -> Shared<EndpointInner> {
        let this = ManuallyDrop::new(self);
        // SAFETY: `this` is wrapped in `ManuallyDrop`, so neither the `Drop`
        // impl of `EndpointRef` nor the destructor of its field will run; the
        // value read here is therefore the only owner of the field.
        unsafe { ptr::read(&this.0) }
    }
351
    /// Stops the worker, waits until this is the last reference to the
    /// endpoint and closes the socket.
    ///
    /// # Errors
    ///
    /// Returns the error from closing the socket.
    async fn shutdown(self) -> io::Result<()> {
        let (worker, idle) = {
            let mut state = self.0.state.lock();
            let idle = state.is_idle();
            if !idle {
                // Connections remain: let the worker exit once they are gone.
                state.exit_on_idle = true;
            }
            (state.worker.take(), idle)
        };
        if let Some(worker) = worker {
            if idle {
                // No connections: cancel the worker right away.
                worker.cancel().await;
            } else {
                // Wait for the worker to finish; its result is ignored.
                let _ = worker.await;
            }
        }

        // Bypass `EndpointRef`'s `Drop` and wait until the shared inner value
        // can be unwrapped.
        let mut this = Some(self.into_inner());
        let inner = poll_fn(move |cx| {
            let s = match Shared::try_unwrap(this.take().unwrap()) {
                Ok(inner) => return Poll::Ready(inner),
                Err(s) => s,
            };

            s.done.register(cx.waker());

            // Try again after registering the waker, so a reference released
            // between the first attempt and the registration is not missed.
            match Shared::try_unwrap(s) {
                Ok(inner) => Poll::Ready(inner),
                Err(s) => {
                    this.replace(s);
                    Poll::Pending
                }
            }
        })
        .await;

        inner.socket.close().await
    }
390}
391
392impl Drop for EndpointRef {
393 fn drop(&mut self) {
394 if Shared::strong_count(&self.0) == 2 {
395 self.0.done.wake();
398 self.0.state.lock().exit_on_idle = true;
400 }
401 }
402}
403
/// Lets an `EndpointRef` be used wherever `&EndpointInner` is expected.
impl Deref for EndpointRef {
    type Target = EndpointInner;

    /// Borrows the shared `EndpointInner`.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
411
/// A QUIC endpoint handle.
///
/// Clones refer to the same underlying endpoint; each clone carries its own
/// copy of `default_client_config`.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Shared endpoint core.
    inner: EndpointRef,
    /// Client configuration used by `connect` when none is passed explicitly.
    pub default_client_config: Option<ClientConfig>,
}
419
420impl Endpoint {
421 pub fn new(
423 socket: UdpSocket,
424 config: EndpointConfig,
425 server_config: Option<ServerConfig>,
426 default_client_config: Option<ClientConfig>,
427 ) -> io::Result<Self> {
428 let inner = EndpointRef(Shared::new(EndpointInner::new(
429 socket,
430 config,
431 server_config,
432 )?));
433 let worker = compio_runtime::spawn({
434 let inner = inner.clone();
435 async move {
436 #[allow(unused)]
437 if let Err(e) = inner.run().await {
438 error!("I/O error: {}", e);
439 }
440 }
441 .in_current_span()
442 });
443 inner.state.lock().worker = Some(worker);
444 Ok(Self {
445 inner,
446 default_client_config,
447 })
448 }
449
450 #[cfg(rustls)]
465 pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
466 let socket = UdpSocket::bind(addr).await?;
468 Self::new(socket, EndpointConfig::default(), None, None)
469 }
470
471 #[cfg(rustls)]
480 pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result<Self> {
481 let socket = UdpSocket::bind(addr).await?;
482 Self::new(socket, EndpointConfig::default(), Some(config), None)
483 }
484
485 pub fn stats(&self) -> EndpointStats {
487 self.inner.state.lock().stats
488 }
489
490 pub fn connect(
492 &self,
493 remote: SocketAddr,
494 server_name: &str,
495 config: Option<ClientConfig>,
496 ) -> Result<Connecting, ConnectError> {
497 let config = config
498 .or_else(|| self.default_client_config.clone())
499 .ok_or(ConnectError::NoDefaultClientConfig)?;
500
501 self.inner.connect(remote, server_name, config)
502 }
503
504 pub async fn wait_incoming(&self) -> Option<Incoming> {
513 future::poll_fn(|cx| self.inner.state.lock().poll_incoming(cx))
514 .await
515 .map(|incoming| Incoming::new(incoming, self.inner.clone()))
516 }
517
518 pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
524 self.inner
525 .state
526 .lock()
527 .endpoint
528 .set_server_config(server_config.map(Arc::new))
529 }
530
531 pub fn local_addr(&self) -> io::Result<SocketAddr> {
533 self.inner.socket.local_addr()
534 }
535
536 pub fn open_connections(&self) -> usize {
538 self.inner.state.lock().endpoint.open_connections()
539 }
540
541 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
548 let reason = Bytes::copy_from_slice(reason);
549 let mut state = self.inner.state.lock();
550 if state.close.is_some() {
551 return;
552 }
553 state.close = Some((error_code, reason.clone()));
554 for conn in state.connections.values() {
555 let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone()));
556 }
557 state.incoming_wakers.drain(..).for_each(Waker::wake);
558 }
559
560 pub async fn shutdown(self) -> io::Result<()> {
577 self.inner.shutdown().await
578 }
579}