1use std::{
2 collections::VecDeque,
3 fmt::Debug,
4 net::{IpAddr, SocketAddr},
5 pin::{Pin, pin},
6 task::{Context, Poll, Waker},
7 time::{Duration, Instant},
8};
9
10use compio_buf::bytes::Bytes;
11use compio_log::Instrument;
12use compio_runtime::JoinHandle;
13use flume::{Receiver, Sender};
14use futures_util::{
15 FutureExt, StreamExt,
16 future::{self, Fuse, FusedFuture, LocalBoxFuture},
17 select, stream,
18};
19#[cfg(rustls)]
20use quinn_proto::crypto::rustls::HandshakeData;
21use quinn_proto::{
22 ConnectionHandle, ConnectionStats, Dir, EndpointEvent, Side, StreamEvent, StreamId, VarInt,
23 congestion::Controller,
24};
25use rustc_hash::FxHashMap as HashMap;
26use thiserror::Error;
27
28use crate::{
29 RecvStream, SendStream, Socket,
30 sync::{
31 mutex_blocking::{Mutex, MutexGuard},
32 shared::Shared,
33 },
34};
35
36#[derive(Debug)]
37pub(crate) enum ConnectionEvent {
38 Close(VarInt, Bytes),
39 Proto(quinn_proto::ConnectionEvent),
40}
41
42#[derive(Debug)]
43pub(crate) struct ConnectionState {
44 pub(crate) conn: quinn_proto::Connection,
45 pub(crate) error: Option<ConnectionError>,
46 connected: bool,
47 worker: Option<JoinHandle<()>>,
48 poller: Option<Waker>,
49 on_connected: Option<Waker>,
50 on_handshake_data: Option<Waker>,
51 datagram_received: VecDeque<Waker>,
52 datagrams_unblocked: VecDeque<Waker>,
53 stream_opened: [VecDeque<Waker>; 2],
54 stream_available: [VecDeque<Waker>; 2],
55 pub(crate) writable: HashMap<StreamId, Waker>,
56 pub(crate) readable: HashMap<StreamId, Waker>,
57 pub(crate) stopped: HashMap<StreamId, Waker>,
58}
59
60impl ConnectionState {
61 fn terminate(&mut self, reason: ConnectionError) {
62 self.error = Some(reason);
63 self.connected = false;
64
65 if let Some(waker) = self.on_handshake_data.take() {
66 waker.wake()
67 }
68 if let Some(waker) = self.on_connected.take() {
69 waker.wake()
70 }
71 self.datagram_received.drain(..).for_each(Waker::wake);
72 self.datagrams_unblocked.drain(..).for_each(Waker::wake);
73 for e in &mut self.stream_opened {
74 e.drain(..).for_each(Waker::wake);
75 }
76 for e in &mut self.stream_available {
77 e.drain(..).for_each(Waker::wake);
78 }
79 wake_all_streams(&mut self.writable);
80 wake_all_streams(&mut self.readable);
81 wake_all_streams(&mut self.stopped);
82 }
83
84 fn close(&mut self, error_code: VarInt, reason: Bytes) {
85 self.conn.close(Instant::now(), error_code, reason);
86 self.terminate(ConnectionError::LocallyClosed);
87 self.wake();
88 }
89
90 pub(crate) fn wake(&mut self) {
91 if let Some(waker) = self.poller.take() {
92 waker.wake()
93 }
94 }
95
96 #[cfg(rustls)]
97 fn handshake_data(&self) -> Option<Box<HandshakeData>> {
98 self.conn
99 .crypto_session()
100 .handshake_data()
101 .map(|data| data.downcast::<HandshakeData>().unwrap())
102 }
103
104 pub(crate) fn check_0rtt(&self) -> bool {
105 self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt()
106 }
107}
108
109fn wake_stream(stream: StreamId, wakers: &mut HashMap<StreamId, Waker>) {
110 if let Some(waker) = wakers.remove(&stream) {
111 waker.wake();
112 }
113}
114
115fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
116 wakers.drain().for_each(|(_, waker)| waker.wake())
117}
118
119#[derive(Debug)]
120pub(crate) struct ConnectionInner {
121 state: Mutex<ConnectionState>,
122 handle: ConnectionHandle,
123 socket: Socket,
124 events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
125 events_rx: Receiver<ConnectionEvent>,
126}
127
128fn implicit_close(this: &Shared<ConnectionInner>) {
129 if Shared::strong_count(this) == 2 {
130 this.state().close(0u32.into(), Bytes::new())
131 }
132}
133
impl ConnectionInner {
    /// Build the shared connection core around `conn`.
    ///
    /// The state starts not connected, with no error, no driver handle and no
    /// registered wakers. `handle` tags the endpoint events sent on `events_tx`;
    /// `events_rx` feeds `ConnectionEvent`s to the driver loop.
    fn new(
        handle: ConnectionHandle,
        conn: quinn_proto::Connection,
        socket: Socket,
        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
        events_rx: Receiver<ConnectionEvent>,
    ) -> Self {
        Self {
            state: Mutex::new(ConnectionState {
                conn,
                connected: false,
                error: None,
                worker: None,
                poller: None,
                on_connected: None,
                on_handshake_data: None,
                datagram_received: VecDeque::new(),
                datagrams_unblocked: VecDeque::new(),
                stream_opened: [VecDeque::new(), VecDeque::new()],
                stream_available: [VecDeque::new(), VecDeque::new()],
                writable: HashMap::default(),
                readable: HashMap::default(),
                stopped: HashMap::default(),
            }),
            handle,
            socket,
            events_tx,
            events_rx,
        }
    }

    /// Lock and return the connection state, even if the connection has
    /// already been terminated.
    #[inline]
    pub(crate) fn state(&self) -> MutexGuard<'_, ConnectionState> {
        self.state.lock()
    }

    /// Lock and return the connection state, or a clone of the terminal error
    /// if the connection has been terminated.
    #[inline]
    pub(crate) fn try_state(&self) -> Result<MutexGuard<'_, ConnectionState>, ConnectionError> {
        let state = self.state();
        if let Some(error) = &state.error {
            Err(error.clone())
        } else {
            Ok(state)
        }
    }

    /// Driver loop for this connection; `Connecting::new` spawns it as a task.
    ///
    /// Each iteration waits for one of four triggers:
    /// - an explicit wake-up (`ConnectionState::wake`);
    /// - the protocol timer;
    /// - a batch of incoming `ConnectionEvent`s;
    /// - completion of the in-flight transmit.
    ///
    /// Then, with the state lock held, it starts at most one new transmit,
    /// re-arms the timer, forwards endpoint events, and dispatches protocol
    /// events to waiting tasks. It exits once the connection is drained and
    /// finally detaches its own `JoinHandle` if nobody has taken it.
    async fn run(&self) {
        // Yields an item whenever `poller` is empty: on the first poll, and after
        // `ConnectionState::wake` has taken the stored waker. Every poll
        // (re)stores the current task's waker so the next `wake` reaches it.
        let mut poller = stream::poll_fn(|cx| {
            let mut state = self.state();
            let ready = state.poller.is_none();
            match &state.poller {
                Some(waker) if waker.will_wake(cx.waker()) => {}
                _ => state.poller = Some(cx.waker().clone()),
            };
            if ready {
                Poll::Ready(Some(()))
            } else {
                Poll::Pending
            }
        })
        .fuse();

        // Armed with the deadline most recently returned by `poll_timeout`.
        let mut timer = Timer::new();
        // Incoming events are handled in batches of up to 100 per wake-up.
        let mut event_stream = self.events_rx.stream().ready_chunks(100);
        // Reusable transmit buffer; `None` while a send is in flight.
        let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
        // The in-flight socket send; resolves to the buffer once it completes.
        let mut transmit_fut = pin!(Fuse::terminated());

        loop {
            // Wait for the next trigger; every branch evaluates to the locked state.
            let mut state = select! {
                _ = poller.select_next_some() => self.state(),
                _ = timer => {
                    // The timer fired: disarm it and let the protocol handle the timeout.
                    timer.reset(None);
                    let mut state = self.state();
                    state.conn.handle_timeout(Instant::now());
                    state
                }
                events = event_stream.select_next_some() => {
                    let mut state = self.state();
                    for event in events {
                        match event {
                            ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason),
                            ConnectionEvent::Proto(event) => state.conn.handle_event(event),
                        }
                    }
                    state
                },
                buf = transmit_fut => {
                    // The send finished: recycle its buffer for the next transmit.
                    let mut buf: Vec<_> = buf;
                    buf.clear();
                    send_buf = Some(buf);
                    self.state()
                },
            };

            // Start a new transmit only when no send is in flight (buffer available).
            if let Some(mut buf) = send_buf.take() {
                if let Some(transmit) = state.conn.poll_transmit(
                    Instant::now(),
                    self.socket.max_gso_segments(),
                    &mut buf,
                ) {
                    transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse())
                } else {
                    // Nothing to send right now: keep the buffer for a later iteration.
                    send_buf = Some(buf);
                }
            }

            timer.reset(state.conn.poll_timeout());

            // Forward endpoint events; channel send errors are deliberately discarded.
            while let Some(event) = state.conn.poll_endpoint_events() {
                let _ = self.events_tx.send((self.handle, event));
            }

            // Dispatch protocol events to the tasks registered for them.
            while let Some(event) = state.conn.poll() {
                use quinn_proto::Event::*;
                match event {
                    HandshakeDataReady => {
                        if let Some(waker) = state.on_handshake_data.take() {
                            waker.wake()
                        }
                    }
                    Connected => {
                        state.connected = true;
                        if let Some(waker) = state.on_connected.take() {
                            waker.wake()
                        }
                        // Client whose 0-RTT data was not accepted: wake every stream
                        // waiter, presumably so streams opened in 0-RTT can observe
                        // the rejection.
                        if state.conn.side().is_client() && !state.conn.accepted_0rtt() {
                            wake_all_streams(&mut state.writable);
                            wake_all_streams(&mut state.readable);
                            wake_all_streams(&mut state.stopped);
                        }
                    }
                    ConnectionLost { reason } => state.terminate(reason.into()),
                    Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable),
                    Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable),
                    Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped),
                    Stream(StreamEvent::Stopped { id, .. }) => {
                        // A stopped stream affects both stop-waiters and writers.
                        wake_stream(id, &mut state.stopped);
                        wake_stream(id, &mut state.writable);
                    }
                    Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    DatagramReceived => state.datagram_received.drain(..).for_each(Waker::wake),
                    DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake),
                }
            }

            if state.conn.is_drained() {
                break;
            }
        }

        // Detach our own handle unless `Connection::closed` has already taken it.
        if let Some(worker) = self.state().worker.take() {
            worker.detach();
        }
    }
}
299
/// Generates the read-only accessors shared by `Connecting` and `Connection`.
///
/// Both types wrap a `Shared<ConnectionInner>` in field `.0`. Each generated
/// method locks the connection state and forwards to `quinn_proto`.
macro_rules! conn_fn {
    () => {
        /// Which side (client or server) of the connection this is.
        pub fn side(&self) -> Side {
            let state = self.0.state();
            state.conn.side()
        }

        /// Local IP address of the connection as reported by `quinn_proto`;
        /// `None` when unavailable.
        pub fn local_ip(&self) -> Option<IpAddr> {
            let state = self.0.state();
            state.conn.local_ip()
        }

        /// The peer's current address.
        pub fn remote_address(&self) -> SocketAddr {
            let state = self.0.state();
            state.conn.remote_address()
        }

        /// Current round-trip time estimate.
        pub fn rtt(&self) -> Duration {
            let state = self.0.state();
            state.conn.rtt()
        }

        /// Snapshot of the connection statistics.
        pub fn stats(&self) -> ConnectionStats {
            let state = self.0.state();
            state.conn.stats()
        }

        /// A boxed copy of the current congestion controller state.
        pub fn congestion_state(&self) -> Box<dyn Controller> {
            let state = self.0.state();
            let controller = state.conn.congestion_state();
            controller.clone_box()
        }

        /// The peer's certificate chain, if it presented one.
        ///
        /// # Panics
        /// If the crypto session's peer identity is not a `Vec<CertificateDer>`.
        pub fn peer_identity(
            &self,
        ) -> Option<Box<Vec<rustls::pki_types::CertificateDer<'static>>>> {
            let state = self.0.state();
            let identity = state.conn.crypto_session().peer_identity()?;
            Some(identity.downcast().unwrap())
        }

        /// An identifier for this connection, stable for its lifetime, derived
        /// from the address of the shared connection core.
        pub fn stable_id(&self) -> usize {
            let core = Shared::as_ptr(&self.0);
            core as usize
        }

        /// Derive keying material from the session's master secret into `output`.
        pub fn export_keying_material(
            &self,
            output: &mut [u8],
            label: &[u8],
            context: &[u8],
        ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> {
            let state = self.0.state();
            let session = state.conn.crypto_session();
            session.export_keying_material(output, label, context)
        }
    };
}
387
388#[derive(Debug)]
390#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
391pub struct Connecting(Shared<ConnectionInner>);
392
393impl Connecting {
394 conn_fn!();
395
396 pub(crate) fn new(
397 handle: ConnectionHandle,
398 conn: quinn_proto::Connection,
399 socket: Socket,
400 events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
401 events_rx: Receiver<ConnectionEvent>,
402 ) -> Self {
403 let inner = Shared::new(ConnectionInner::new(
404 handle, conn, socket, events_tx, events_rx,
405 ));
406 let worker = compio_runtime::spawn({
407 let inner = inner.clone();
408 async move { inner.run().await }.in_current_span()
409 });
410 inner.state().worker = Some(worker);
411 Self(inner)
412 }
413
414 #[cfg(rustls)]
416 pub async fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
417 future::poll_fn(|cx| {
418 let mut state = self.0.try_state()?;
419 if let Some(data) = state.handshake_data() {
420 return Poll::Ready(Ok(data));
421 }
422
423 match &state.on_handshake_data {
424 Some(waker) if waker.will_wake(cx.waker()) => {}
425 _ => state.on_handshake_data = Some(cx.waker().clone()),
426 }
427
428 Poll::Pending
429 })
430 .await
431 }
432
433 pub fn into_0rtt(self) -> Result<Connection, Self> {
481 let is_ok = {
482 let state = self.0.state();
483 state.conn.has_0rtt() || state.conn.side().is_server()
484 };
485 if is_ok {
486 Ok(Connection(self.0.clone()))
487 } else {
488 Err(self)
489 }
490 }
491}
492
493impl Future for Connecting {
494 type Output = Result<Connection, ConnectionError>;
495
496 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
497 let mut state = self.0.try_state()?;
498
499 if state.connected {
500 return Poll::Ready(Ok(Connection(self.0.clone())));
501 }
502
503 match &state.on_connected {
504 Some(waker) if waker.will_wake(cx.waker()) => {}
505 _ => state.on_connected = Some(cx.waker().clone()),
506 }
507
508 Poll::Pending
509 }
510}
511
512impl Drop for Connecting {
513 fn drop(&mut self) {
514 implicit_close(&self.0)
515 }
516}
517
518#[derive(Debug, Clone)]
520pub struct Connection(Shared<ConnectionInner>);
521
522impl Connection {
523 conn_fn!();
524
525 pub fn force_key_update(&self) {
529 self.0.state().conn.force_key_update()
530 }
531
532 #[cfg(rustls)]
534 pub fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
535 Ok(self.0.try_state()?.handshake_data().unwrap())
536 }
537
538 pub fn max_datagram_size(&self) -> Option<usize> {
551 self.0.state().conn.datagrams().max_size()
552 }
553
554 pub fn datagram_send_buffer_space(&self) -> usize {
560 self.0.state().conn.datagrams().send_buffer_space()
561 }
562
563 pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
570 let mut state = self.0.state();
571 state.conn.set_max_concurrent_streams(Dir::Uni, count);
572 state.wake();
574 }
575
576 pub fn set_send_window(&self, send_window: u64) {
578 let mut state = self.0.state();
579 state.conn.set_send_window(send_window);
580 state.wake();
581 }
582
583 pub fn set_receive_window(&self, receive_window: VarInt) {
585 let mut state = self.0.state();
586 state.conn.set_receive_window(receive_window);
587 state.wake();
588 }
589
590 pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
597 let mut state = self.0.state();
598 state.conn.set_max_concurrent_streams(Dir::Bi, count);
599 state.wake();
601 }
602
603 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
640 self.0
641 .state()
642 .close(error_code, Bytes::copy_from_slice(reason));
643 }
644
645 pub async fn closed(&self) -> ConnectionError {
647 let worker = self.0.state().worker.take();
648 if let Some(worker) = worker {
649 let _ = worker.await;
650 }
651
652 self.0.try_state().unwrap_err()
653 }
654
655 pub fn close_reason(&self) -> Option<ConnectionError> {
659 self.0.try_state().err()
660 }
661
662 fn poll_recv_datagram(&self, cx: &mut Context) -> Poll<Result<Bytes, ConnectionError>> {
663 let mut state = self.0.try_state()?;
664 if let Some(bytes) = state.conn.datagrams().recv() {
665 return Poll::Ready(Ok(bytes));
666 }
667 state.datagram_received.push_back(cx.waker().clone());
668 Poll::Pending
669 }
670
671 pub fn try_recv_datagram(&self) -> Result<Option<Bytes>, ConnectionError> {
674 let mut state = self.0.try_state()?;
675 Ok(state.conn.datagrams().recv())
676 }
677
678 pub async fn recv_datagram(&self) -> Result<Bytes, ConnectionError> {
680 future::poll_fn(|cx| self.poll_recv_datagram(cx)).await
681 }
682
683 fn try_send_datagram(
684 &self,
685 cx: Option<&mut Context>,
686 data: Bytes,
687 ) -> Result<(), Result<SendDatagramError, Bytes>> {
688 use quinn_proto::SendDatagramError::*;
689 let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?;
690 state
691 .conn
692 .datagrams()
693 .send(data, cx.is_none())
694 .map_err(|err| match err {
695 UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer),
696 Disabled => Ok(SendDatagramError::Disabled),
697 TooLarge => Ok(SendDatagramError::TooLarge),
698 Blocked(data) => {
699 state
700 .datagrams_unblocked
701 .push_back(cx.unwrap().waker().clone());
702 Err(data)
703 }
704 })?;
705 state.wake();
706 Ok(())
707 }
708
709 pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
715 self.try_send_datagram(None, data).map_err(Result::unwrap)
716 }
717
718 pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> {
728 let mut data = Some(data);
729 future::poll_fn(
730 |cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) {
731 Ok(()) => Poll::Ready(Ok(())),
732 Err(Ok(e)) => Poll::Ready(Err(e)),
733 Err(Err(b)) => {
734 data.replace(b);
735 Poll::Pending
736 }
737 },
738 )
739 .await
740 }
741
742 fn poll_open_stream(
743 &self,
744 cx: Option<&mut Context>,
745 dir: Dir,
746 ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
747 let mut state = self.0.try_state()?;
748 if let Some(stream) = state.conn.streams().open(dir) {
749 Poll::Ready(Ok((
750 stream,
751 state.conn.side().is_client() && state.conn.is_handshaking(),
752 )))
753 } else {
754 if let Some(cx) = cx {
755 state.stream_available[dir as usize].push_back(cx.waker().clone());
756 }
757 Poll::Pending
758 }
759 }
760
761 pub fn open_uni(&self) -> Result<SendStream, OpenStreamError> {
767 if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Uni)? {
768 Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
769 } else {
770 Err(OpenStreamError::StreamsExhausted)
771 }
772 }
773
774 pub async fn open_uni_wait(&self) -> Result<SendStream, ConnectionError> {
783 let (stream, is_0rtt) =
784 future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?;
785 Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
786 }
787
788 pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> {
794 if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Bi)? {
795 Ok((
796 SendStream::new(self.0.clone(), stream, is_0rtt),
797 RecvStream::new(self.0.clone(), stream, is_0rtt),
798 ))
799 } else {
800 Err(OpenStreamError::StreamsExhausted)
801 }
802 }
803
804 pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
813 let (stream, is_0rtt) =
814 future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?;
815 Ok((
816 SendStream::new(self.0.clone(), stream, is_0rtt),
817 RecvStream::new(self.0.clone(), stream, is_0rtt),
818 ))
819 }
820
821 fn poll_accept_stream(
822 &self,
823 cx: &mut Context,
824 dir: Dir,
825 ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
826 let mut state = self.0.try_state()?;
827 if let Some(stream) = state.conn.streams().accept(dir) {
828 state.wake();
829 Poll::Ready(Ok((stream, state.conn.is_handshaking())))
830 } else {
831 state.stream_opened[dir as usize].push_back(cx.waker().clone());
832 Poll::Pending
833 }
834 }
835
836 pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
838 let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?;
839 Ok(RecvStream::new(self.0.clone(), stream, is_0rtt))
840 }
841
842 pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
854 let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?;
855 Ok((
856 SendStream::new(self.0.clone(), stream, is_0rtt),
857 RecvStream::new(self.0.clone(), stream, is_0rtt),
858 ))
859 }
860
861 pub async fn accepted_0rtt(&self) -> Result<bool, ConnectionError> {
866 future::poll_fn(|cx| {
867 let mut state = self.0.try_state()?;
868
869 if state.connected {
870 return Poll::Ready(Ok(state.conn.accepted_0rtt()));
871 }
872
873 match &state.on_connected {
874 Some(waker) if waker.will_wake(cx.waker()) => {}
875 _ => state.on_connected = Some(cx.waker().clone()),
876 }
877
878 Poll::Pending
879 })
880 .await
881 }
882}
883
884impl PartialEq for Connection {
885 fn eq(&self, other: &Self) -> bool {
886 Shared::ptr_eq(&self.0, &other.0)
887 }
888}
889
890impl Eq for Connection {}
891
892impl Drop for Connection {
893 fn drop(&mut self) {
894 implicit_close(&self.0)
895 }
896}
897
898struct Timer {
899 deadline: Option<Instant>,
900 fut: Fuse<LocalBoxFuture<'static, ()>>,
901}
902
903impl Timer {
904 fn new() -> Self {
905 Self {
906 deadline: None,
907 fut: Fuse::terminated(),
908 }
909 }
910
911 fn reset(&mut self, deadline: Option<Instant>) {
912 if let Some(deadline) = deadline {
913 if self.deadline.is_none() || self.deadline != Some(deadline) {
914 self.fut = compio_runtime::time::sleep_until(deadline)
915 .boxed_local()
916 .fuse();
917 }
918 } else {
919 self.fut = Fuse::terminated();
920 }
921 self.deadline = deadline;
922 }
923}
924
925impl Future for Timer {
926 type Output = ();
927
928 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
929 self.fut.poll_unpin(cx)
930 }
931}
932
933impl FusedFuture for Timer {
934 fn is_terminated(&self) -> bool {
935 self.fut.is_terminated()
936 }
937}
938
939#[derive(Debug, Error, Clone, PartialEq, Eq)]
941pub enum ConnectionError {
942 #[error("peer doesn't implement any supported version")]
944 VersionMismatch,
945 #[error(transparent)]
948 TransportError(#[from] quinn_proto::TransportError),
949 #[error("aborted by peer: {0}")]
951 ConnectionClosed(quinn_proto::ConnectionClose),
952 #[error("closed by peer: {0}")]
954 ApplicationClosed(quinn_proto::ApplicationClose),
955 #[error("reset by peer")]
958 Reset,
959 #[error("timed out")]
967 TimedOut,
968 #[error("closed")]
970 LocallyClosed,
971 #[error("CIDs exhausted")]
976 CidsExhausted,
977}
978
979impl From<quinn_proto::ConnectionError> for ConnectionError {
980 fn from(value: quinn_proto::ConnectionError) -> Self {
981 use quinn_proto::ConnectionError::*;
982
983 match value {
984 VersionMismatch => ConnectionError::VersionMismatch,
985 TransportError(e) => ConnectionError::TransportError(e),
986 ConnectionClosed(e) => ConnectionError::ConnectionClosed(e),
987 ApplicationClosed(e) => ConnectionError::ApplicationClosed(e),
988 Reset => ConnectionError::Reset,
989 TimedOut => ConnectionError::TimedOut,
990 LocallyClosed => ConnectionError::LocallyClosed,
991 CidsExhausted => ConnectionError::CidsExhausted,
992 }
993 }
994}
995
996#[derive(Debug, Error, Clone, Eq, PartialEq)]
998pub enum SendDatagramError {
999 #[error("datagrams not supported by peer")]
1001 UnsupportedByPeer,
1002 #[error("datagram support disabled")]
1004 Disabled,
1005 #[error("datagram too large")]
1010 TooLarge,
1011 #[error("connection lost")]
1013 ConnectionLost(#[from] ConnectionError),
1014}
1015
1016#[derive(Debug, Error, Clone, Eq, PartialEq)]
1018pub enum OpenStreamError {
1019 #[error("connection lost")]
1021 ConnectionLost(#[from] ConnectionError),
1022 #[error("streams exhausted")]
1024 StreamsExhausted,
1025}
1026
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    //! Adapters implementing the `h3` / `h3-datagram` QUIC traits for this crate's types.

    use std::sync::Arc;

    use compio_buf::bytes::Buf;
    use futures_util::ready;
    use h3::{
        error::Code,
        quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, WriteBuf},
    };
    use h3_datagram::{
        datagram::EncodedDatagram,
        quic_traits::{
            DatagramConnectionExt, RecvDatagram, SendDatagram, SendDatagramErrorIncoming,
        },
    };

    use super::*;
    use crate::send_stream::h3_impl::SendStream;

    /// Application closes and timeouts map to their `h3` counterparts; every
    /// other error is wrapped as `Undefined`.
    impl From<ConnectionError> for ConnectionErrorIncoming {
        fn from(e: ConnectionError) -> Self {
            match e {
                ConnectionError::ApplicationClosed(close) => Self::ApplicationClose {
                    error_code: close.error_code.into_inner(),
                },
                ConnectionError::TimedOut => Self::Timeout,
                other => Self::Undefined(Arc::new(other)),
            }
        }
    }

    /// A connection error seen while using a stream.
    impl From<ConnectionError> for StreamErrorIncoming {
        fn from(e: ConnectionError) -> Self {
            Self::ConnectionErrorIncoming {
                connection_error: e.into(),
            }
        }
    }

    /// Unsupported/disabled datagrams are reported as `NotAvailable`.
    impl From<SendDatagramError> for SendDatagramErrorIncoming {
        fn from(e: SendDatagramError) -> Self {
            match e {
                SendDatagramError::UnsupportedByPeer | SendDatagramError::Disabled => {
                    Self::NotAvailable
                }
                SendDatagramError::TooLarge => Self::TooLarge,
                SendDatagramError::ConnectionLost(lost) => Self::ConnectionError(lost.into()),
            }
        }
    }

    impl<B> SendDatagram<B> for Connection
    where
        B: Buf,
    {
        /// Flatten the encoded datagram into contiguous bytes and send it
        /// without waiting.
        fn send_datagram<T: Into<EncodedDatagram<B>>>(
            &mut self,
            data: T,
        ) -> Result<(), SendDatagramErrorIncoming> {
            let mut encoded: EncodedDatagram<B> = data.into();
            let len = encoded.remaining();
            let payload = encoded.copy_to_bytes(len);
            Connection::send_datagram(self, payload).map_err(Into::into)
        }
    }

    impl RecvDatagram for Connection {
        type Buffer = Bytes;

        fn poll_incoming_datagram(
            &mut self,
            cx: &mut core::task::Context<'_>,
        ) -> Poll<Result<Self::Buffer, ConnectionErrorIncoming>> {
            self.poll_recv_datagram(cx).map_err(Into::into)
        }
    }

    /// Both datagram handlers are clones of the connection handle.
    impl<B: Buf> DatagramConnectionExt<B> for Connection {
        type RecvDatagramHandler = Self;
        type SendDatagramHandler = Self;

        fn send_datagram_handler(&self) -> Self::SendDatagramHandler {
            self.clone()
        }

        fn recv_datagram_handler(&self) -> Self::RecvDatagramHandler {
            self.clone()
        }
    }

    /// A bidirectional stream for `h3`: a send half plus a receive half on the same id.
    pub struct BidiStream<B> {
        send: SendStream<B>,
        recv: RecvStream,
    }

    impl<B> BidiStream<B> {
        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            let send = SendStream::new(conn.clone(), stream, is_0rtt);
            let recv = RecvStream::new(conn, stream, is_0rtt);
            Self { send, recv }
        }
    }

    impl<B> quic::BidiStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        type RecvStream = RecvStream;
        type SendStream = SendStream<B>;

        fn split(self) -> (Self::SendStream, Self::RecvStream) {
            let Self { send, recv } = self;
            (send, recv)
        }
    }

    /// Receive-side operations delegate to the receive half.
    impl<B> quic::RecvStream for BidiStream<B>
    where
        B: Buf,
    {
        type Buf = Bytes;

        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            self.recv.poll_data(cx)
        }

        fn stop_sending(&mut self, error_code: u64) {
            self.recv.stop_sending(error_code)
        }

        fn recv_id(&self) -> quic::StreamId {
            self.recv.recv_id()
        }
    }

    /// Send-side operations delegate to the send half.
    impl<B> quic::SendStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_ready(cx)
        }

        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            self.send.send_data(data)
        }

        fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_finish(cx)
        }

        fn reset(&mut self, reset_code: u64) {
            self.send.reset(reset_code)
        }

        fn send_id(&self) -> quic::StreamId {
            self.send.send_id()
        }
    }

    impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            self.send.poll_send(cx, buf)
        }
    }

    /// Stream opener handed out by `quic::Connection::opener`.
    #[derive(Clone)]
    pub struct OpenStreams(Connection);

    impl<B> quic::OpenStreams<B> for OpenStreams
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            let conn = &self.0;
            let (id, zero_rtt) = ready!(conn.poll_open_stream(Some(cx), Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(conn.0.clone(), id, zero_rtt)))
        }

        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            let conn = &self.0;
            let (id, zero_rtt) = ready!(conn.poll_open_stream(Some(cx), Dir::Uni))?;
            Poll::Ready(Ok(SendStream::new(conn.0.clone(), id, zero_rtt)))
        }

        /// Close the whole connection.
        ///
        /// # Panics
        /// If `code` does not fit in a QUIC variable-length integer.
        fn close(&mut self, code: Code, reason: &[u8]) {
            let error_code: VarInt = code.value().try_into().expect("invalid code");
            self.0.close(error_code, reason)
        }
    }

    impl<B> quic::OpenStreams<B> for Connection
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            let (id, zero_rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), id, zero_rtt)))
        }

        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            let (id, zero_rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?;
            Poll::Ready(Ok(SendStream::new(self.0.clone(), id, zero_rtt)))
        }

        /// Close the connection.
        ///
        /// # Panics
        /// If `code` does not fit in a QUIC variable-length integer.
        fn close(&mut self, code: Code, reason: &[u8]) {
            let error_code: VarInt = code.value().try_into().expect("invalid code");
            Connection::close(self, error_code, reason)
        }
    }

    impl<B> quic::Connection<B> for Connection
    where
        B: Buf,
    {
        type OpenStreams = OpenStreams;
        type RecvStream = RecvStream;

        fn poll_accept_recv(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
            let (id, zero_rtt) = ready!(self.poll_accept_stream(cx, Dir::Uni))?;
            Poll::Ready(Ok(RecvStream::new(self.0.clone(), id, zero_rtt)))
        }

        fn poll_accept_bidi(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
            let (id, zero_rtt) = ready!(self.poll_accept_stream(cx, Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), id, zero_rtt)))
        }

        fn opener(&self) -> Self::OpenStreams {
            OpenStreams(self.clone())
        }
    }
}