1use std::{
2 collections::VecDeque,
3 fmt::Debug,
4 io,
5 mem::ManuallyDrop,
6 net::{SocketAddr, SocketAddrV6},
7 ops::Deref,
8 pin::pin,
9 sync::Arc,
10 task::{Context, Poll, Waker},
11 time::Instant,
12};
13
14use compio_buf::{BufResult, bytes::Bytes};
15use compio_log::{Instrument, error};
16#[cfg(rustls)]
17use compio_net::ToSocketAddrsAsync;
18use compio_net::UdpSocket;
19use compio_runtime::JoinHandle;
20use flume::{Receiver, Sender, unbounded};
21use futures_util::{
22 FutureExt, StreamExt,
23 future::{self},
24 select,
25 task::AtomicWaker,
26};
27use quinn_proto::{
28 ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
29 EndpointEvent, ServerConfig, Transmit, VarInt,
30};
31use rustc_hash::FxHashMap as HashMap;
32
33use crate::{
34 Connecting, ConnectionEvent, Incoming, RecvMeta, Socket,
35 sync::{mutex_blocking::Mutex, shared::Shared},
36};
37
/// Mutable state of an endpoint, shared between the public handles and the
/// background worker; it lives behind the mutex in `EndpointInner::state`.
#[derive(Debug)]
struct EndpointState {
    /// The `quinn_proto` endpoint state machine.
    endpoint: quinn_proto::Endpoint,
    /// Join handle of the worker task running `EndpointInner::run`; `None`
    /// before it is spawned and after `EndpointRef::shutdown` has taken it.
    worker: Option<JoinHandle<()>>,
    /// Channel senders used to forward events to each registered connection.
    connections: HashMap<ConnectionHandle, Sender<ConnectionEvent>>,
    /// Error code and reason given to `Endpoint::close`; `Some` once closed.
    close: Option<(VarInt, Bytes)>,
    /// Checked by the worker loop after every datagram / event batch: when set
    /// and no connections remain, the loop exits.
    exit_on_idle: bool,
    /// Incoming connection attempts waiting to be returned by `poll_incoming`.
    incoming: VecDeque<quinn_proto::Incoming>,
    /// Wakers of tasks waiting in `poll_incoming` for an incoming attempt.
    incoming_wakers: VecDeque<Waker>,
}
48
49impl EndpointState {
50 fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec<u8>, Transmit)) {
51 let now = Instant::now();
52 for data in buf[..meta.len]
53 .chunks(meta.stride.min(meta.len))
54 .map(Into::into)
55 {
56 let mut resp_buf = Vec::new();
57 match self.endpoint.handle(
58 now,
59 meta.remote,
60 meta.local_ip,
61 meta.ecn,
62 data,
63 &mut resp_buf,
64 ) {
65 Some(DatagramEvent::NewConnection(incoming)) => {
66 if self.close.is_none() {
67 self.incoming.push_back(incoming);
68 } else {
69 let transmit = self.endpoint.refuse(incoming, &mut resp_buf);
70 respond_fn(resp_buf, transmit);
71 }
72 }
73 Some(DatagramEvent::ConnectionEvent(ch, event)) => {
74 let _ = self
75 .connections
76 .get(&ch)
77 .unwrap()
78 .send(ConnectionEvent::Proto(event));
79 }
80 Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit),
81 None => {}
82 }
83 }
84 }
85
86 fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) {
87 if event.is_drained() {
88 self.connections.remove(&ch);
89 }
90 if let Some(event) = self.endpoint.handle_event(ch, event) {
91 let _ = self
92 .connections
93 .get(&ch)
94 .unwrap()
95 .send(ConnectionEvent::Proto(event));
96 }
97 }
98
99 fn is_idle(&self) -> bool {
100 self.connections.is_empty()
101 }
102
103 fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<quinn_proto::Incoming>> {
104 if self.close.is_none() {
105 if let Some(incoming) = self.incoming.pop_front() {
106 Poll::Ready(Some(incoming))
107 } else {
108 self.incoming_wakers.push_back(cx.waker().clone());
109 Poll::Pending
110 }
111 } else {
112 Poll::Ready(None)
113 }
114 }
115
116 fn new_connection(
117 &mut self,
118 handle: ConnectionHandle,
119 conn: quinn_proto::Connection,
120 socket: Socket,
121 events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
122 ) -> Connecting {
123 let (tx, rx) = unbounded();
124 if let Some((error_code, reason)) = &self.close {
125 tx.send(ConnectionEvent::Close(*error_code, reason.clone()))
126 .unwrap();
127 }
128 self.connections.insert(handle, tx);
129 Connecting::new(handle, conn, socket, events_tx, rx)
130 }
131}
132
/// Both ends of one `flume` channel, stored together.
type ChannelPair<T> = (Sender<T>, Receiver<T>);
134
/// Shared core of an endpoint; handles to it are [`EndpointRef`]s.
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable endpoint state, guarded by `mutex_blocking::Mutex`.
    state: Mutex<EndpointState>,
    /// The UDP socket used for all sends and receives of this endpoint.
    socket: Socket,
    /// Whether the socket's local address is IPv6; `connect` then rewrites
    /// IPv4 remotes as IPv4-mapped IPv6 addresses.
    ipv6: bool,
    /// Channel carrying `(handle, EndpointEvent)` pairs to the worker; each
    /// new connection receives a clone of the sender.
    events: ChannelPair<(ConnectionHandle, EndpointEvent)>,
    /// Woken by `EndpointRef::drop` when it sees a strong count of 2, so that
    /// `EndpointRef::shutdown` can retry unwrapping the last reference.
    done: AtomicWaker,
}
143
144impl EndpointInner {
145 fn new(
146 socket: UdpSocket,
147 config: EndpointConfig,
148 server_config: Option<ServerConfig>,
149 ) -> io::Result<Self> {
150 let socket = Socket::new(socket)?;
151 let ipv6 = socket.local_addr()?.is_ipv6();
152 let allow_mtud = !socket.may_fragment();
153
154 Ok(Self {
155 state: Mutex::new(EndpointState {
156 endpoint: quinn_proto::Endpoint::new(
157 Arc::new(config),
158 server_config.map(Arc::new),
159 allow_mtud,
160 None,
161 ),
162 worker: None,
163 connections: HashMap::default(),
164 close: None,
165 exit_on_idle: false,
166 incoming: VecDeque::new(),
167 incoming_wakers: VecDeque::new(),
168 }),
169 socket,
170 ipv6,
171 events: unbounded(),
172 done: AtomicWaker::new(),
173 })
174 }
175
176 fn connect(
177 &self,
178 remote: SocketAddr,
179 server_name: &str,
180 config: ClientConfig,
181 ) -> Result<Connecting, ConnectError> {
182 let mut state = self.state.lock();
183
184 if state.worker.is_none() {
185 return Err(ConnectError::EndpointStopping);
186 }
187 if remote.is_ipv6() && !self.ipv6 {
188 return Err(ConnectError::InvalidRemoteAddress(remote));
189 }
190 let remote = if self.ipv6 {
191 SocketAddr::V6(match remote {
192 SocketAddr::V4(addr) => {
193 SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
194 }
195 SocketAddr::V6(addr) => addr,
196 })
197 } else {
198 remote
199 };
200
201 let (handle, conn) = state
202 .endpoint
203 .connect(Instant::now(), config, remote, server_name)?;
204
205 Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
206 }
207
208 fn respond(&self, buf: Vec<u8>, transmit: Transmit) {
209 let socket = self.socket.clone();
210 compio_runtime::spawn(async move {
211 socket.send(buf, &transmit).await;
212 })
213 .detach();
214 }
215
216 pub(crate) fn accept(
217 &self,
218 incoming: quinn_proto::Incoming,
219 server_config: Option<ServerConfig>,
220 ) -> Result<Connecting, ConnectionError> {
221 let mut state = self.state.lock();
222 let mut resp_buf = Vec::new();
223 let now = Instant::now();
224 match state
225 .endpoint
226 .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new))
227 {
228 Ok((handle, conn)) => {
229 Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
230 }
231 Err(err) => {
232 if let Some(transmit) = err.response {
233 self.respond(resp_buf, transmit);
234 }
235 Err(err.cause)
236 }
237 }
238 }
239
240 pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) {
241 let mut state = self.state.lock();
242 let mut resp_buf = Vec::new();
243 let transmit = state.endpoint.refuse(incoming, &mut resp_buf);
244 self.respond(resp_buf, transmit);
245 }
246
247 #[allow(clippy::result_large_err)]
248 pub(crate) fn retry(
249 &self,
250 incoming: quinn_proto::Incoming,
251 ) -> Result<(), quinn_proto::RetryError> {
252 let mut state = self.state.lock();
253 let mut resp_buf = Vec::new();
254 let transmit = state.endpoint.retry(incoming, &mut resp_buf)?;
255 self.respond(resp_buf, transmit);
256 Ok(())
257 }
258
259 pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) {
260 let mut state = self.state.lock();
261 state.endpoint.ignore(incoming);
262 }
263
264 async fn run(&self) -> io::Result<()> {
265 let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);
266
267 let mut recv_fut = pin!(
268 self.socket
269 .recv(Vec::with_capacity(
270 self.state
271 .lock()
272 .endpoint
273 .config()
274 .get_max_udp_payload_size()
275 .min(64 * 1024) as usize
276 * self.socket.max_gro_segments(),
277 ))
278 .fuse()
279 );
280
281 let mut event_stream = self.events.1.stream().ready_chunks(100);
282
283 loop {
284 let mut state = select! {
285 BufResult(res, recv_buf) = recv_fut => {
286 let mut state = self.state.lock();
287 match res {
288 Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
289 Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
290 #[cfg(windows)]
291 Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
292 Err(e) => break Err(e),
293 }
294 recv_fut.set(self.socket.recv(recv_buf).fuse());
295 state
296 },
297 events = event_stream.select_next_some() => {
298 let mut state = self.state.lock();
299 for (ch, event) in events {
300 state.handle_event(ch, event);
301 }
302 state
303 },
304 };
305
306 if state.exit_on_idle && state.is_idle() {
307 break Ok(());
308 }
309 if !state.incoming.is_empty() {
310 let n = state.incoming.len().min(state.incoming_wakers.len());
311 state.incoming_wakers.drain(..n).for_each(Waker::wake);
312 }
313 }
314 }
315}
316
/// Cloneable, reference-counted handle to an [`EndpointInner`].
///
/// Its `Drop` impl watches the reference count, and `shutdown` waits until it
/// holds the only remaining reference.
#[derive(Debug, Clone)]
pub(crate) struct EndpointRef(Shared<EndpointInner>);
319
impl EndpointRef {
    /// Tries to move the [`EndpointInner`] out of the shared pointer without
    /// consuming `self`.
    ///
    /// Returns `Some` only when `self` holds the sole strong reference.
    ///
    /// # Safety
    ///
    /// The shared pointer is duplicated with `ptr::read`, which bypasses the
    /// reference count. When this returns `Some`, the allocation has been
    /// consumed, so the caller must never drop or otherwise use `self`
    /// afterwards (e.g. keep it in `ManuallyDrop`). When it returns `None`,
    /// the duplicate has been forgotten and `self` remains valid.
    unsafe fn try_unwrap_inner(&self) -> Option<EndpointInner> {
        // SAFETY: bitwise copy of `self.0`. The copy is either consumed by a
        // successful `try_unwrap` (the caller guarantees `self` is not dropped
        // afterwards) or forgotten below, so the count is never decremented twice.
        let ptr = unsafe { std::ptr::read(&self.0) };
        match Shared::try_unwrap(ptr) {
            Ok(inner) => Some(inner),
            Err(ptr) => {
                // Discard the duplicate without touching the reference count.
                std::mem::forget(ptr);
                None
            }
        }
    }

    /// Stops the worker, waits to become the last reference, then closes the
    /// socket.
    ///
    /// The worker handle is taken from the state. If no connections are
    /// registered, the worker is cancelled. Otherwise `exit_on_idle` is set
    /// and the worker is awaited; its loop exits once no connections remain.
    /// Afterwards this waits, via `done`, until every other `EndpointRef`
    /// has been dropped.
    ///
    /// NOTE(review): `self` is held in `ManuallyDrop` while waiting. If this
    /// future is dropped before it completes, that reference is never released
    /// and the inner endpoint (including its socket) leaks.
    async fn shutdown(self) -> io::Result<()> {
        let (worker, idle) = {
            let mut state = self.0.state.lock();
            let idle = state.is_idle();
            if !idle {
                state.exit_on_idle = true;
            }
            (state.worker.take(), idle)
        };
        if let Some(worker) = worker {
            if idle {
                worker.cancel().await;
            } else {
                let _ = worker.await;
            }
        }

        // `this` must never be dropped: `try_unwrap_inner` may consume the
        // shared pointer it refers to.
        let this = ManuallyDrop::new(self);
        let inner = future::poll_fn(move |cx| {
            // SAFETY: `this` is in `ManuallyDrop` and is not used again after a
            // successful unwrap, because the future completes right away.
            if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
                return Poll::Ready(inner);
            }

            // Register before checking again, so that a wake from `Drop`
            // occurring between the two checks is not lost.
            this.done.register(cx.waker());

            // SAFETY: same as above.
            if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
                Poll::Ready(inner)
            } else {
                Poll::Pending
            }
        })
        .await;

        inner.socket.close().await
    }
}
369
impl Drop for EndpointRef {
    /// Runs when one reference is about to go away and exactly one other will
    /// remain (the count still includes `self` here, hence `== 2`). It wakes a
    /// pending `shutdown` so that it retries unwrapping, and sets
    /// `exit_on_idle` so the worker loop stops once no connections remain.
    ///
    /// NOTE(review): the count is only decremented after this method returns,
    /// when the field is dropped. A waiter polled in between, which is possible
    /// if `Shared` is used across threads, could still see two references and
    /// go back to sleep. Confirm `Shared`'s threading model.
    fn drop(&mut self) {
        if Shared::strong_count(&self.0) == 2 {
            self.0.done.wake();
            self.0.state.lock().exit_on_idle = true;
        }
    }
}
381
382impl Deref for EndpointRef {
383 type Target = EndpointInner;
384
385 fn deref(&self) -> &Self::Target {
386 &self.0
387 }
388}
389
/// A QUIC endpoint on a single UDP socket.
///
/// Outgoing connections are started with [`Endpoint::connect`], and incoming
/// connection attempts are obtained with [`Endpoint::wait_incoming`]. Clones
/// share the same underlying endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    inner: EndpointRef,
    /// Client configuration used by [`Endpoint::connect`] when it is called
    /// without an explicit configuration.
    pub default_client_config: Option<ClientConfig>,
}
397
398impl Endpoint {
399 pub fn new(
401 socket: UdpSocket,
402 config: EndpointConfig,
403 server_config: Option<ServerConfig>,
404 default_client_config: Option<ClientConfig>,
405 ) -> io::Result<Self> {
406 let inner = EndpointRef(Shared::new(EndpointInner::new(
407 socket,
408 config,
409 server_config,
410 )?));
411 let worker = compio_runtime::spawn({
412 let inner = inner.clone();
413 async move {
414 #[allow(unused)]
415 if let Err(e) = inner.run().await {
416 error!("I/O error: {}", e);
417 }
418 }
419 .in_current_span()
420 });
421 inner.state.lock().worker = Some(worker);
422 Ok(Self {
423 inner,
424 default_client_config,
425 })
426 }
427
428 #[cfg(rustls)]
443 pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
444 let socket = UdpSocket::bind(addr).await?;
446 Self::new(socket, EndpointConfig::default(), None, None)
447 }
448
449 #[cfg(rustls)]
458 pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result<Self> {
459 let socket = UdpSocket::bind(addr).await?;
460 Self::new(socket, EndpointConfig::default(), Some(config), None)
461 }
462
463 pub fn connect(
465 &self,
466 remote: SocketAddr,
467 server_name: &str,
468 config: Option<ClientConfig>,
469 ) -> Result<Connecting, ConnectError> {
470 let config = config
471 .or_else(|| self.default_client_config.clone())
472 .ok_or(ConnectError::NoDefaultClientConfig)?;
473
474 self.inner.connect(remote, server_name, config)
475 }
476
477 pub async fn wait_incoming(&self) -> Option<Incoming> {
486 future::poll_fn(|cx| self.inner.state.lock().poll_incoming(cx))
487 .await
488 .map(|incoming| Incoming::new(incoming, self.inner.clone()))
489 }
490
491 pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
497 self.inner
498 .state
499 .lock()
500 .endpoint
501 .set_server_config(server_config.map(Arc::new))
502 }
503
504 pub fn local_addr(&self) -> io::Result<SocketAddr> {
506 self.inner.socket.local_addr()
507 }
508
509 pub fn open_connections(&self) -> usize {
511 self.inner.state.lock().endpoint.open_connections()
512 }
513
514 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
521 let reason = Bytes::copy_from_slice(reason);
522 let mut state = self.inner.state.lock();
523 if state.close.is_some() {
524 return;
525 }
526 state.close = Some((error_code, reason.clone()));
527 for conn in state.connections.values() {
528 let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone()));
529 }
530 state.incoming_wakers.drain(..).for_each(Waker::wake);
531 }
532
533 pub async fn shutdown(self) -> io::Result<()> {
550 self.inner.shutdown().await
551 }
552}