1use std::{
2 collections::VecDeque,
3 fmt::Debug,
4 net::{IpAddr, SocketAddr},
5 pin::{Pin, pin},
6 task::{Context, Poll, Waker},
7 time::{Duration, Instant},
8};
9
10use compio_buf::bytes::Bytes;
11use compio_log::Instrument;
12use compio_runtime::JoinHandle;
13use flume::{Receiver, Sender};
14use futures_util::{
15 FutureExt, StreamExt,
16 future::{self, Fuse, FusedFuture, LocalBoxFuture},
17 select, stream,
18};
19#[cfg(rustls)]
20use quinn_proto::crypto::rustls::HandshakeData;
21use quinn_proto::{
22 ConnectionHandle, ConnectionStats, Dir, EndpointEvent, Side, StreamEvent, StreamId, VarInt,
23 congestion::Controller,
24};
25use rustc_hash::FxHashMap as HashMap;
26use thiserror::Error;
27
28use crate::{
29 RecvStream, SendStream, Socket,
30 sync::{
31 mutex_blocking::{Mutex, MutexGuard},
32 shared::Shared,
33 },
34};
35
/// Events delivered to a connection's worker task through its event channel.
#[derive(Debug)]
pub(crate) enum ConnectionEvent {
    /// Request to close the connection with the given error code and reason.
    Close(VarInt, Bytes),
    /// A protocol event that the worker forwards to
    /// `quinn_proto::Connection::handle_event`.
    Proto(quinn_proto::ConnectionEvent),
}
41
/// Mutable per-connection state: the protocol state machine plus the wakers
/// of every task currently waiting on this connection.
#[derive(Debug)]
pub(crate) struct ConnectionState {
    /// The underlying protocol state machine.
    pub(crate) conn: quinn_proto::Connection,
    /// Reason recorded when the connection was terminated; `None` until then.
    pub(crate) error: Option<ConnectionError>,
    /// Set when the `Connected` event is observed, cleared on termination.
    connected: bool,
    /// Join handle of the spawned worker task; taken by `Connection::closed`.
    worker: Option<JoinHandle<()>>,
    /// Waker of the worker's poller stream; taken and woken by `wake`.
    poller: Option<Waker>,
    /// Single waiter for the `Connected` event.
    on_connected: Option<Waker>,
    /// Single waiter for the `HandshakeDataReady` event.
    on_handshake_data: Option<Waker>,
    /// Tasks waiting for an incoming datagram.
    datagram_received: VecDeque<Waker>,
    /// Tasks waiting for datagram send capacity.
    datagrams_unblocked: VecDeque<Waker>,
    /// Tasks waiting to accept a peer-opened stream, indexed by `Dir as usize`.
    stream_opened: [VecDeque<Waker>; 2],
    /// Tasks waiting to open a new stream, indexed by `Dir as usize`.
    stream_available: [VecDeque<Waker>; 2],
    /// Per-stream wakers, woken on `Writable` and `Stopped` stream events.
    pub(crate) writable: HashMap<StreamId, Waker>,
    /// Per-stream wakers, woken on `Readable` stream events.
    pub(crate) readable: HashMap<StreamId, Waker>,
    /// Per-stream wakers, woken on `Finished` and `Stopped` stream events.
    pub(crate) stopped: HashMap<StreamId, Waker>,
}
59
60impl ConnectionState {
61 fn terminate(&mut self, reason: ConnectionError) {
62 self.error = Some(reason);
63 self.connected = false;
64
65 if let Some(waker) = self.on_handshake_data.take() {
66 waker.wake()
67 }
68 if let Some(waker) = self.on_connected.take() {
69 waker.wake()
70 }
71 self.datagram_received.drain(..).for_each(Waker::wake);
72 self.datagrams_unblocked.drain(..).for_each(Waker::wake);
73 for e in &mut self.stream_opened {
74 e.drain(..).for_each(Waker::wake);
75 }
76 for e in &mut self.stream_available {
77 e.drain(..).for_each(Waker::wake);
78 }
79 wake_all_streams(&mut self.writable);
80 wake_all_streams(&mut self.readable);
81 wake_all_streams(&mut self.stopped);
82 }
83
84 fn close(&mut self, error_code: VarInt, reason: Bytes) {
85 self.conn.close(Instant::now(), error_code, reason);
86 self.terminate(ConnectionError::LocallyClosed);
87 self.wake();
88 }
89
90 pub(crate) fn wake(&mut self) {
91 if let Some(waker) = self.poller.take() {
92 waker.wake()
93 }
94 }
95
96 #[cfg(rustls)]
97 fn handshake_data(&self) -> Option<Box<HandshakeData>> {
98 self.conn
99 .crypto_session()
100 .handshake_data()
101 .map(|data| data.downcast::<HandshakeData>().unwrap())
102 }
103
104 pub(crate) fn check_0rtt(&self) -> bool {
105 self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt()
106 }
107}
108
109fn wake_stream(stream: StreamId, wakers: &mut HashMap<StreamId, Waker>) {
110 if let Some(waker) = wakers.remove(&stream) {
111 waker.wake();
112 }
113}
114
115fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
116 wakers.drain().for_each(|(_, waker)| waker.wake())
117}
118
/// Shared core of a connection: the locked state plus the I/O and channel
/// endpoints used by the worker task.
#[derive(Debug)]
pub(crate) struct ConnectionInner {
    /// Protocol state and registered wakers.
    state: Mutex<ConnectionState>,
    /// Handle sent along with every endpoint event emitted by this connection.
    handle: ConnectionHandle,
    /// Socket used by the worker to send outgoing transmits.
    socket: Socket,
    /// Channel on which the worker sends `(handle, EndpointEvent)` pairs.
    events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
    /// Channel from which the worker receives [`ConnectionEvent`]s.
    events_rx: Receiver<ConnectionEvent>,
}
127
128fn implicit_close(this: &Shared<ConnectionInner>) {
129 if Shared::strong_count(this) == 2 {
130 this.state().close(0u32.into(), Bytes::new())
131 }
132}
133
impl ConnectionInner {
    /// Builds the shared connection core around `conn` with no error, no
    /// worker and empty waker registries.
    fn new(
        handle: ConnectionHandle,
        conn: quinn_proto::Connection,
        socket: Socket,
        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
        events_rx: Receiver<ConnectionEvent>,
    ) -> Self {
        Self {
            state: Mutex::new(ConnectionState {
                conn,
                connected: false,
                error: None,
                worker: None,
                poller: None,
                on_connected: None,
                on_handshake_data: None,
                datagram_received: VecDeque::new(),
                datagrams_unblocked: VecDeque::new(),
                stream_opened: [VecDeque::new(), VecDeque::new()],
                stream_available: [VecDeque::new(), VecDeque::new()],
                writable: HashMap::default(),
                readable: HashMap::default(),
                stopped: HashMap::default(),
            }),
            handle,
            socket,
            events_tx,
            events_rx,
        }
    }

    /// Locks and returns the connection state.
    #[inline]
    pub(crate) fn state(&self) -> MutexGuard<'_, ConnectionState> {
        self.state.lock()
    }

    /// Locks the connection state, returning a clone of the recorded error
    /// instead of the guard if the connection has been terminated.
    #[inline]
    pub(crate) fn try_state(&self) -> Result<MutexGuard<'_, ConnectionState>, ConnectionError> {
        let state = self.state();
        if let Some(error) = &state.error {
            Err(error.clone())
        } else {
            Ok(state)
        }
    }

    /// Worker loop driving the connection until the protocol state reports
    /// that it is drained.
    ///
    /// Each iteration waits for one of: an explicit wake-up, the protocol
    /// timer, a batch of incoming events, or completion of the in-flight
    /// send. It then starts the next transmit, re-arms the timer, forwards
    /// endpoint events and dispatches application events to waiting tasks.
    async fn run(&self) {
        // Yields an item whenever no waker is stored in `state.poller`, i.e.
        // on the first poll or after `ConnectionState::wake` took it. The
        // current waker is (re)stored on every poll.
        let mut poller = stream::poll_fn(|cx| {
            let mut state = self.state();
            let ready = state.poller.is_none();
            match &state.poller {
                Some(waker) if waker.will_wake(cx.waker()) => {}
                _ => state.poller = Some(cx.waker().clone()),
            };
            if ready {
                Poll::Ready(Some(()))
            } else {
                Poll::Pending
            }
        })
        .fuse();

        let mut timer = Timer::new();
        // Incoming events are handled in batches of up to 100.
        let mut event_stream = self.events_rx.stream().ready_chunks(100);
        // Reusable transmit buffer; `None` while a send owns it.
        let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
        let mut transmit_fut = pin!(Fuse::terminated());

        loop {
            // Every arm ends by handing back the locked state for the
            // bookkeeping below.
            let mut state = select! {
                _ = poller.select_next_some() => self.state(),
                _ = timer => {
                    timer.reset(None);
                    let mut state = self.state();
                    state.conn.handle_timeout(Instant::now());
                    state
                }
                events = event_stream.select_next_some() => {
                    let mut state = self.state();
                    for event in events {
                        match event {
                            ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason),
                            ConnectionEvent::Proto(event) => state.conn.handle_event(event),
                        }
                    }
                    state
                },
                buf = transmit_fut => {
                    // The send finished: reclaim and clear the buffer.
                    let mut buf: Vec<_> = buf;
                    buf.clear();
                    send_buf = Some(buf);
                    self.state()
                },
            };

            // Start at most one send at a time; the buffer travels with it.
            if let Some(mut buf) = send_buf.take() {
                if let Some(transmit) = state.conn.poll_transmit(
                    Instant::now(),
                    self.socket.max_gso_segments(),
                    &mut buf,
                ) {
                    transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse())
                } else {
                    send_buf = Some(buf);
                }
            }

            timer.reset(state.conn.poll_timeout());

            // Forward endpoint events; a failed channel send is ignored.
            while let Some(event) = state.conn.poll_endpoint_events() {
                let _ = self.events_tx.send((self.handle, event));
            }

            // Dispatch application-facing events to the registered wakers.
            while let Some(event) = state.conn.poll() {
                use quinn_proto::Event::*;
                match event {
                    HandshakeDataReady => {
                        if let Some(waker) = state.on_handshake_data.take() {
                            waker.wake()
                        }
                    }
                    Connected => {
                        state.connected = true;
                        if let Some(waker) = state.on_connected.take() {
                            waker.wake()
                        }
                        // Client whose 0-RTT was not accepted: wake all
                        // per-stream waiters.
                        if state.conn.side().is_client() && !state.conn.accepted_0rtt() {
                            wake_all_streams(&mut state.writable);
                            wake_all_streams(&mut state.readable);
                            wake_all_streams(&mut state.stopped);
                        }
                    }
                    ConnectionLost { reason } => state.terminate(reason.into()),
                    Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable),
                    Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable),
                    Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped),
                    Stream(StreamEvent::Stopped { id, .. }) => {
                        wake_stream(id, &mut state.stopped);
                        wake_stream(id, &mut state.writable);
                    }
                    Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    DatagramReceived => state.datagram_received.drain(..).for_each(Waker::wake),
                    DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake),
                }
            }

            if state.conn.is_drained() {
                break;
            }
        }
    }
}
294
// Generates the accessor methods shared by `Connecting` and `Connection`.
// Every method except `stable_id` locks the connection state for the
// duration of the call.
macro_rules! conn_fn {
    () => {
        /// Returns the side of the connection as reported by the protocol
        /// state.
        pub fn side(&self) -> Side {
            self.0.state().conn.side()
        }

        /// Returns the local IP address reported by the protocol state, if
        /// any.
        pub fn local_ip(&self) -> Option<IpAddr> {
            self.0.state().conn.local_ip()
        }

        /// Returns the peer's address as currently known to the protocol
        /// state.
        pub fn remote_address(&self) -> SocketAddr {
            self.0.state().conn.remote_address()
        }

        /// Returns the round-trip time value reported by the protocol state.
        pub fn rtt(&self) -> Duration {
            self.0.state().conn.rtt()
        }

        /// Returns the connection statistics reported by the protocol state.
        pub fn stats(&self) -> ConnectionStats {
            self.0.state().conn.stats()
        }

        /// Returns a boxed clone of the congestion controller state.
        pub fn congestion_state(&self) -> Box<dyn Controller> {
            self.0.state().conn.congestion_state().clone_box()
        }

        /// Returns the peer identity of the crypto session, if any, as a
        /// certificate chain.
        ///
        /// # Panics
        ///
        /// Panics if the identity is not a
        /// `Vec<rustls::pki_types::CertificateDer<'static>>`.
        pub fn peer_identity(
            &self,
        ) -> Option<Box<Vec<rustls::pki_types::CertificateDer<'static>>>> {
            self.0
                .state()
                .conn
                .crypto_session()
                .peer_identity()
                .map(|v| v.downcast().unwrap())
        }

        /// Returns an identifier derived from the address of the shared
        /// connection state.
        pub fn stable_id(&self) -> usize {
            Shared::as_ptr(&self.0) as usize
        }

        /// Fills `output` with keying material exported from the crypto
        /// session for the given `label` and `context`.
        ///
        /// # Errors
        ///
        /// Returns the error reported by the crypto session.
        pub fn export_keying_material(
            &self,
            output: &mut [u8],
            label: &[u8],
            context: &[u8],
        ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> {
            self.0
                .state()
                .conn
                .crypto_session()
                .export_keying_material(output, label, context)
        }
    };
}
382
/// A connection whose handshake may still be in progress.
///
/// Awaiting it resolves to a [`Connection`] once the `Connected` event has
/// been observed, or to the recorded [`ConnectionError`].
#[derive(Debug)]
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
pub struct Connecting(Shared<ConnectionInner>);
387
388impl Connecting {
389 conn_fn!();
390
391 pub(crate) fn new(
392 handle: ConnectionHandle,
393 conn: quinn_proto::Connection,
394 socket: Socket,
395 events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
396 events_rx: Receiver<ConnectionEvent>,
397 ) -> Self {
398 let inner = Shared::new(ConnectionInner::new(
399 handle, conn, socket, events_tx, events_rx,
400 ));
401 let worker = compio_runtime::spawn({
402 let inner = inner.clone();
403 async move { inner.run().await }.in_current_span()
404 });
405 inner.state().worker = Some(worker);
406 Self(inner)
407 }
408
409 #[cfg(rustls)]
411 pub async fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
412 future::poll_fn(|cx| {
413 let mut state = self.0.try_state()?;
414 if let Some(data) = state.handshake_data() {
415 return Poll::Ready(Ok(data));
416 }
417
418 match &state.on_handshake_data {
419 Some(waker) if waker.will_wake(cx.waker()) => {}
420 _ => state.on_handshake_data = Some(cx.waker().clone()),
421 }
422
423 Poll::Pending
424 })
425 .await
426 }
427
428 pub fn into_0rtt(self) -> Result<Connection, Self> {
476 let is_ok = {
477 let state = self.0.state();
478 state.conn.has_0rtt() || state.conn.side().is_server()
479 };
480 if is_ok {
481 Ok(Connection(self.0.clone()))
482 } else {
483 Err(self)
484 }
485 }
486}
487
488impl Future for Connecting {
489 type Output = Result<Connection, ConnectionError>;
490
491 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
492 let mut state = self.0.try_state()?;
493
494 if state.connected {
495 return Poll::Ready(Ok(Connection(self.0.clone())));
496 }
497
498 match &state.on_connected {
499 Some(waker) if waker.will_wake(cx.waker()) => {}
500 _ => state.on_connected = Some(cx.waker().clone()),
501 }
502
503 Poll::Pending
504 }
505}
506
impl Drop for Connecting {
    /// Delegates to `implicit_close`, which closes the connection only when
    /// exactly two strong references to the shared state exist.
    fn drop(&mut self) {
        implicit_close(&self.0)
    }
}
512
/// A handle to a connection.
///
/// Cloning yields another handle to the same shared connection state.
#[derive(Debug, Clone)]
pub struct Connection(Shared<ConnectionInner>);
516
517impl Connection {
518 conn_fn!();
519
520 pub fn force_key_update(&self) {
524 self.0.state().conn.force_key_update()
525 }
526
527 #[cfg(rustls)]
529 pub fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
530 Ok(self.0.try_state()?.handshake_data().unwrap())
531 }
532
533 pub fn max_datagram_size(&self) -> Option<usize> {
546 self.0.state().conn.datagrams().max_size()
547 }
548
549 pub fn datagram_send_buffer_space(&self) -> usize {
555 self.0.state().conn.datagrams().send_buffer_space()
556 }
557
558 pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
565 let mut state = self.0.state();
566 state.conn.set_max_concurrent_streams(Dir::Uni, count);
567 state.wake();
569 }
570
571 pub fn set_send_window(&self, send_window: u64) {
573 let mut state = self.0.state();
574 state.conn.set_send_window(send_window);
575 state.wake();
576 }
577
578 pub fn set_receive_window(&self, receive_window: VarInt) {
580 let mut state = self.0.state();
581 state.conn.set_receive_window(receive_window);
582 state.wake();
583 }
584
585 pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
592 let mut state = self.0.state();
593 state.conn.set_max_concurrent_streams(Dir::Bi, count);
594 state.wake();
596 }
597
598 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
635 self.0
636 .state()
637 .close(error_code, Bytes::copy_from_slice(reason));
638 }
639
640 pub async fn closed(&self) -> ConnectionError {
642 let worker = self.0.state().worker.take();
643 if let Some(worker) = worker {
644 let _ = worker.await;
645 }
646
647 self.0.try_state().unwrap_err()
648 }
649
650 pub fn close_reason(&self) -> Option<ConnectionError> {
654 self.0.try_state().err()
655 }
656
657 fn poll_recv_datagram(&self, cx: &mut Context) -> Poll<Result<Bytes, ConnectionError>> {
658 let mut state = self.0.try_state()?;
659 if let Some(bytes) = state.conn.datagrams().recv() {
660 return Poll::Ready(Ok(bytes));
661 }
662 state.datagram_received.push_back(cx.waker().clone());
663 Poll::Pending
664 }
665
666 pub fn try_recv_datagram(&self) -> Result<Option<Bytes>, ConnectionError> {
669 let mut state = self.0.try_state()?;
670 Ok(state.conn.datagrams().recv())
671 }
672
673 pub async fn recv_datagram(&self) -> Result<Bytes, ConnectionError> {
675 future::poll_fn(|cx| self.poll_recv_datagram(cx)).await
676 }
677
678 fn try_send_datagram(
679 &self,
680 cx: Option<&mut Context>,
681 data: Bytes,
682 ) -> Result<(), Result<SendDatagramError, Bytes>> {
683 use quinn_proto::SendDatagramError::*;
684 let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?;
685 state
686 .conn
687 .datagrams()
688 .send(data, cx.is_none())
689 .map_err(|err| match err {
690 UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer),
691 Disabled => Ok(SendDatagramError::Disabled),
692 TooLarge => Ok(SendDatagramError::TooLarge),
693 Blocked(data) => {
694 state
695 .datagrams_unblocked
696 .push_back(cx.unwrap().waker().clone());
697 Err(data)
698 }
699 })?;
700 state.wake();
701 Ok(())
702 }
703
704 pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
710 self.try_send_datagram(None, data).map_err(Result::unwrap)
711 }
712
713 pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> {
723 let mut data = Some(data);
724 future::poll_fn(
725 |cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) {
726 Ok(()) => Poll::Ready(Ok(())),
727 Err(Ok(e)) => Poll::Ready(Err(e)),
728 Err(Err(b)) => {
729 data.replace(b);
730 Poll::Pending
731 }
732 },
733 )
734 .await
735 }
736
737 fn poll_open_stream(
738 &self,
739 cx: Option<&mut Context>,
740 dir: Dir,
741 ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
742 let mut state = self.0.try_state()?;
743 if let Some(stream) = state.conn.streams().open(dir) {
744 Poll::Ready(Ok((
745 stream,
746 state.conn.side().is_client() && state.conn.is_handshaking(),
747 )))
748 } else {
749 if let Some(cx) = cx {
750 state.stream_available[dir as usize].push_back(cx.waker().clone());
751 }
752 Poll::Pending
753 }
754 }
755
756 pub fn open_uni(&self) -> Result<SendStream, OpenStreamError> {
762 if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Uni)? {
763 Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
764 } else {
765 Err(OpenStreamError::StreamsExhausted)
766 }
767 }
768
769 pub async fn open_uni_wait(&self) -> Result<SendStream, ConnectionError> {
778 let (stream, is_0rtt) =
779 future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?;
780 Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
781 }
782
783 pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> {
789 if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Bi)? {
790 Ok((
791 SendStream::new(self.0.clone(), stream, is_0rtt),
792 RecvStream::new(self.0.clone(), stream, is_0rtt),
793 ))
794 } else {
795 Err(OpenStreamError::StreamsExhausted)
796 }
797 }
798
799 pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
808 let (stream, is_0rtt) =
809 future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?;
810 Ok((
811 SendStream::new(self.0.clone(), stream, is_0rtt),
812 RecvStream::new(self.0.clone(), stream, is_0rtt),
813 ))
814 }
815
816 fn poll_accept_stream(
817 &self,
818 cx: &mut Context,
819 dir: Dir,
820 ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
821 let mut state = self.0.try_state()?;
822 if let Some(stream) = state.conn.streams().accept(dir) {
823 state.wake();
824 Poll::Ready(Ok((stream, state.conn.is_handshaking())))
825 } else {
826 state.stream_opened[dir as usize].push_back(cx.waker().clone());
827 Poll::Pending
828 }
829 }
830
831 pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
833 let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?;
834 Ok(RecvStream::new(self.0.clone(), stream, is_0rtt))
835 }
836
837 pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
849 let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?;
850 Ok((
851 SendStream::new(self.0.clone(), stream, is_0rtt),
852 RecvStream::new(self.0.clone(), stream, is_0rtt),
853 ))
854 }
855
856 pub async fn accepted_0rtt(&self) -> Result<bool, ConnectionError> {
861 future::poll_fn(|cx| {
862 let mut state = self.0.try_state()?;
863
864 if state.connected {
865 return Poll::Ready(Ok(state.conn.accepted_0rtt()));
866 }
867
868 match &state.on_connected {
869 Some(waker) if waker.will_wake(cx.waker()) => {}
870 _ => state.on_connected = Some(cx.waker().clone()),
871 }
872
873 Poll::Pending
874 })
875 .await
876 }
877}
878
impl PartialEq for Connection {
    /// Two handles are equal when they point to the same shared connection
    /// state.
    fn eq(&self, other: &Self) -> bool {
        Shared::ptr_eq(&self.0, &other.0)
    }
}
884
// Equality is pointer identity of the shared state, so it is a full
// equivalence relation.
impl Eq for Connection {}
886
impl Drop for Connection {
    /// Delegates to `implicit_close`, which closes the connection only when
    /// exactly two strong references to the shared state exist.
    fn drop(&mut self) {
        implicit_close(&self.0)
    }
}
892
/// Resettable timer used by the worker loop.
struct Timer {
    /// Deadline passed to the most recent `reset` call.
    deadline: Option<Instant>,
    /// Fused sleep future for the current deadline; terminated when unarmed.
    fut: Fuse<LocalBoxFuture<'static, ()>>,
}
897
898impl Timer {
899 fn new() -> Self {
900 Self {
901 deadline: None,
902 fut: Fuse::terminated(),
903 }
904 }
905
906 fn reset(&mut self, deadline: Option<Instant>) {
907 if let Some(deadline) = deadline {
908 if self.deadline.is_none() || self.deadline != Some(deadline) {
909 self.fut = compio_runtime::time::sleep_until(deadline)
910 .boxed_local()
911 .fuse();
912 }
913 } else {
914 self.fut = Fuse::terminated();
915 }
916 self.deadline = deadline;
917 }
918}
919
impl Future for Timer {
    type Output = ();

    /// Polls the inner fused sleep future.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.fut.poll_unpin(cx)
    }
}
927
impl FusedFuture for Timer {
    /// Reports whether the inner fused sleep future is terminated.
    fn is_terminated(&self) -> bool {
        self.fut.is_terminated()
    }
}
933
/// Reasons why a connection might be lost.
///
/// Mirrors the variants of `quinn_proto::ConnectionError`.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
    /// The peer doesn't implement any supported version.
    #[error("peer doesn't implement any supported version")]
    VersionMismatch,
    /// A transport error, displayed transparently.
    #[error(transparent)]
    TransportError(#[from] quinn_proto::TransportError),
    /// The connection was aborted by the peer.
    #[error("aborted by peer: {0}")]
    ConnectionClosed(quinn_proto::ConnectionClose),
    /// The connection was closed by the peer.
    #[error("closed by peer: {0}")]
    ApplicationClosed(quinn_proto::ApplicationClose),
    /// The connection was reset by the peer.
    #[error("reset by peer")]
    Reset,
    /// The connection timed out.
    #[error("timed out")]
    TimedOut,
    /// The connection was closed; this is the reason recorded by a local
    /// close.
    #[error("closed")]
    LocallyClosed,
    /// Connection IDs were exhausted.
    #[error("CIDs exhausted")]
    CidsExhausted,
}
973
974impl From<quinn_proto::ConnectionError> for ConnectionError {
975 fn from(value: quinn_proto::ConnectionError) -> Self {
976 use quinn_proto::ConnectionError::*;
977
978 match value {
979 VersionMismatch => ConnectionError::VersionMismatch,
980 TransportError(e) => ConnectionError::TransportError(e),
981 ConnectionClosed(e) => ConnectionError::ConnectionClosed(e),
982 ApplicationClosed(e) => ConnectionError::ApplicationClosed(e),
983 Reset => ConnectionError::Reset,
984 TimedOut => ConnectionError::TimedOut,
985 LocallyClosed => ConnectionError::LocallyClosed,
986 CidsExhausted => ConnectionError::CidsExhausted,
987 }
988 }
989}
990
/// Errors that can arise when sending a datagram.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum SendDatagramError {
    /// Datagrams are not supported by the peer.
    #[error("datagrams not supported by peer")]
    UnsupportedByPeer,
    /// Datagram support is disabled.
    #[error("datagram support disabled")]
    Disabled,
    /// The datagram is too large.
    #[error("datagram too large")]
    TooLarge,
    /// The connection has been terminated; carries the recorded reason.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
}
1010
/// Errors that can arise when opening a stream without waiting.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum OpenStreamError {
    /// The connection has been terminated; carries the recorded reason.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// No stream can be opened right now.
    #[error("streams exhausted")]
    StreamsExhausted,
}
1021
/// Implementations of the `h3` / `h3_datagram` QUIC traits for this crate's
/// connection and stream types. Compiled only with the `h3` feature.
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use std::sync::Arc;

    use compio_buf::bytes::Buf;
    use futures_util::ready;
    use h3::{
        error::Code,
        quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, WriteBuf},
    };
    use h3_datagram::{
        datagram::EncodedDatagram,
        quic_traits::{
            DatagramConnectionExt, RecvDatagram, SendDatagram, SendDatagramErrorIncoming,
        },
    };

    use super::*;
    use crate::send_stream::h3_impl::SendStream;

    impl From<ConnectionError> for ConnectionErrorIncoming {
        /// Maps application closes and timeouts onto their dedicated `h3`
        /// variants; every other error is wrapped as `Undefined`.
        fn from(e: ConnectionError) -> Self {
            use ConnectionError::*;
            match e {
                ApplicationClosed(e) => Self::ApplicationClose {
                    error_code: e.error_code.into_inner(),
                },
                TimedOut => Self::Timeout,

                e => Self::Undefined(Arc::new(e)),
            }
        }
    }

    impl From<ConnectionError> for StreamErrorIncoming {
        /// Wraps the converted connection error as a stream-level error.
        fn from(e: ConnectionError) -> Self {
            Self::ConnectionErrorIncoming {
                connection_error: e.into(),
            }
        }
    }

    impl From<SendDatagramError> for SendDatagramErrorIncoming {
        /// `UnsupportedByPeer` and `Disabled` both map to `NotAvailable`.
        fn from(e: SendDatagramError) -> Self {
            use SendDatagramError::*;
            match e {
                UnsupportedByPeer | Disabled => Self::NotAvailable,
                TooLarge => Self::TooLarge,
                ConnectionLost(e) => Self::ConnectionError(e.into()),
            }
        }
    }

    impl<B> SendDatagram<B> for Connection
    where
        B: Buf,
    {
        /// Copies the encoded datagram into a contiguous `Bytes` and sends it
        /// through the non-waiting `Connection::send_datagram`.
        fn send_datagram<T: Into<EncodedDatagram<B>>>(
            &mut self,
            data: T,
        ) -> Result<(), SendDatagramErrorIncoming> {
            let mut buf: EncodedDatagram<B> = data.into();
            let buf = buf.copy_to_bytes(buf.remaining());
            Ok(Connection::send_datagram(self, buf)?)
        }
    }

    impl RecvDatagram for Connection {
        type Buffer = Bytes;

        /// Polls for the next incoming datagram.
        fn poll_incoming_datagram(
            &mut self,
            cx: &mut core::task::Context<'_>,
        ) -> Poll<Result<Self::Buffer, ConnectionErrorIncoming>> {
            Poll::Ready(Ok(ready!(self.poll_recv_datagram(cx))?))
        }
    }

    impl<B: Buf> DatagramConnectionExt<B> for Connection {
        type RecvDatagramHandler = Self;
        type SendDatagramHandler = Self;

        /// Returns a clone of this connection handle as the send handler.
        fn send_datagram_handler(&self) -> Self::SendDatagramHandler {
            self.clone()
        }

        /// Returns a clone of this connection handle as the receive handler.
        fn recv_datagram_handler(&self) -> Self::RecvDatagramHandler {
            self.clone()
        }
    }

    /// A bidirectional stream: the send and receive halves of one stream id.
    pub struct BidiStream<B> {
        /// Sending half.
        send: SendStream<B>,
        /// Receiving half.
        recv: RecvStream,
    }

    impl<B> BidiStream<B> {
        /// Builds both halves for `stream` on the shared connection `conn`.
        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                send: SendStream::new(conn.clone(), stream, is_0rtt),
                recv: RecvStream::new(conn, stream, is_0rtt),
            }
        }
    }

    impl<B> quic::BidiStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        type RecvStream = RecvStream;
        type SendStream = SendStream<B>;

        /// Splits the stream into its send and receive halves.
        fn split(self) -> (Self::SendStream, Self::RecvStream) {
            (self.send, self.recv)
        }
    }

    // Receive operations delegate to the receive half.
    impl<B> quic::RecvStream for BidiStream<B>
    where
        B: Buf,
    {
        type Buf = Bytes;

        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            self.recv.poll_data(cx)
        }

        fn stop_sending(&mut self, error_code: u64) {
            self.recv.stop_sending(error_code)
        }

        fn recv_id(&self) -> quic::StreamId {
            self.recv.recv_id()
        }
    }

    // Send operations delegate to the send half.
    impl<B> quic::SendStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_ready(cx)
        }

        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            self.send.send_data(data)
        }

        fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_finish(cx)
        }

        fn reset(&mut self, reset_code: u64) {
            self.send.reset(reset_code)
        }

        fn send_id(&self) -> quic::StreamId {
            self.send.send_id()
        }
    }

    // Unframed sends delegate to the send half.
    impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            self.send.poll_send(cx, buf)
        }
    }

    /// Stream opener handed out by `quic::Connection::opener`; wraps a cloned
    /// connection handle.
    #[derive(Clone)]
    pub struct OpenStreams(Connection);

    impl<B> quic::OpenStreams<B> for OpenStreams
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        /// Polls for opening a bidirectional stream.
        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.0.clone(), stream, is_0rtt)))
        }

        /// Polls for opening a unidirectional send stream.
        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Uni))?;
            Poll::Ready(Ok(SendStream::new(self.0.0.clone(), stream, is_0rtt)))
        }

        /// Closes the connection with `code` and `reason`.
        ///
        /// Panics with "invalid code" if the code value cannot be converted
        /// to the close error-code type.
        fn close(&mut self, code: Code, reason: &[u8]) {
            self.0
                .close(code.value().try_into().expect("invalid code"), reason)
        }
    }

    impl<B> quic::OpenStreams<B> for Connection
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        /// Polls for opening a bidirectional stream.
        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Polls for opening a unidirectional send stream.
        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?;
            Poll::Ready(Ok(SendStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Closes the connection with `code` and `reason`.
        ///
        /// Panics with "invalid code" if the code value cannot be converted
        /// to the close error-code type.
        fn close(&mut self, code: Code, reason: &[u8]) {
            Connection::close(self, code.value().try_into().expect("invalid code"), reason)
        }
    }

    impl<B> quic::Connection<B> for Connection
    where
        B: Buf,
    {
        type OpenStreams = OpenStreams;
        type RecvStream = RecvStream;

        /// Polls for accepting a peer-opened unidirectional stream.
        fn poll_accept_recv(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Uni))?;
            Poll::Ready(Ok(RecvStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Polls for accepting a peer-opened bidirectional stream.
        fn poll_accept_bidi(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Returns an opener backed by a clone of this connection handle.
        fn opener(&self) -> Self::OpenStreams {
            OpenStreams(self.clone())
        }
    }
}