1use std::{
2 collections::VecDeque,
3 fmt::Debug,
4 net::{IpAddr, SocketAddr},
5 pin::{Pin, pin},
6 task::{Context, Poll, Waker},
7 time::{Duration, Instant},
8};
9
10use compio_buf::bytes::Bytes;
11use compio_log::Instrument;
12use compio_runtime::JoinHandle;
13use flume::{Receiver, Sender};
14use futures_util::{
15 Future, FutureExt, StreamExt,
16 future::{self, Fuse, FusedFuture, LocalBoxFuture},
17 select, stream,
18};
19#[cfg(rustls)]
20use quinn_proto::crypto::rustls::HandshakeData;
21use quinn_proto::{
22 ConnectionHandle, ConnectionStats, Dir, EndpointEvent, StreamEvent, StreamId, VarInt,
23 congestion::Controller,
24};
25use rustc_hash::FxHashMap as HashMap;
26use thiserror::Error;
27
28use crate::{
29 RecvStream, SendStream, Socket,
30 sync::{
31 mutex_blocking::{Mutex, MutexGuard},
32 shared::Shared,
33 },
34};
35
/// Messages delivered to a connection's worker task through its `events_rx` channel.
#[derive(Debug)]
pub(crate) enum ConnectionEvent {
    /// Request to close the connection with the given application error code and reason;
    /// the worker handles it with `ConnectionState::close`.
    Close(VarInt, Bytes),
    /// A protocol-level event forwarded to `quinn_proto::Connection::handle_event`.
    Proto(quinn_proto::ConnectionEvent),
}
41
/// Mutable per-connection state shared by the public handles and the worker task.
///
/// It is always accessed through the mutex owned by `ConnectionInner`.
#[derive(Debug)]
pub(crate) struct ConnectionState {
    /// The underlying quinn-proto connection state machine.
    pub(crate) conn: quinn_proto::Connection,
    /// Terminal error. `Some` once `terminate` has run; `try_state` then fails with a clone of it.
    pub(crate) error: Option<ConnectionError>,
    /// Set on the `Connected` event and cleared again by `terminate`.
    connected: bool,
    /// Handle of the spawned task running `ConnectionInner::run`.
    worker: Option<JoinHandle<()>>,
    /// Waker of the worker loop; taking and waking it (`wake`) makes the worker iterate again.
    poller: Option<Waker>,
    /// Single waiter for the `Connected` event (also woken on termination).
    on_connected: Option<Waker>,
    /// Single waiter for the `HandshakeDataReady` event (also woken on termination).
    on_handshake_data: Option<Waker>,
    /// Waiters woken on `DatagramReceived` and on termination.
    datagram_received: VecDeque<Waker>,
    /// Waiters woken on `DatagramsUnblocked` and on termination.
    datagrams_unblocked: VecDeque<Waker>,
    /// Waiters for peer-opened streams, indexed by `Dir as usize`.
    stream_opened: [VecDeque<Waker>; 2],
    /// Waiters for stream-open capacity, indexed by `Dir as usize`.
    stream_available: [VecDeque<Waker>; 2],
    /// Per-stream waiters woken on `Writable`/`Stopped` stream events.
    /// NOTE(review): presumably registered by the stream types; their code is not visible here.
    pub(crate) writable: HashMap<StreamId, Waker>,
    /// Per-stream waiters woken on `Readable` stream events.
    pub(crate) readable: HashMap<StreamId, Waker>,
    /// Per-stream waiters woken on `Finished`/`Stopped` stream events.
    pub(crate) stopped: HashMap<StreamId, Waker>,
}
59
60impl ConnectionState {
61 fn terminate(&mut self, reason: ConnectionError) {
62 self.error = Some(reason);
63 self.connected = false;
64
65 if let Some(waker) = self.on_handshake_data.take() {
66 waker.wake()
67 }
68 if let Some(waker) = self.on_connected.take() {
69 waker.wake()
70 }
71 self.datagram_received.drain(..).for_each(Waker::wake);
72 self.datagrams_unblocked.drain(..).for_each(Waker::wake);
73 for e in &mut self.stream_opened {
74 e.drain(..).for_each(Waker::wake);
75 }
76 for e in &mut self.stream_available {
77 e.drain(..).for_each(Waker::wake);
78 }
79 wake_all_streams(&mut self.writable);
80 wake_all_streams(&mut self.readable);
81 wake_all_streams(&mut self.stopped);
82 }
83
84 fn close(&mut self, error_code: VarInt, reason: Bytes) {
85 self.conn.close(Instant::now(), error_code, reason);
86 self.terminate(ConnectionError::LocallyClosed);
87 self.wake();
88 }
89
90 pub(crate) fn wake(&mut self) {
91 if let Some(waker) = self.poller.take() {
92 waker.wake()
93 }
94 }
95
96 #[cfg(rustls)]
97 fn handshake_data(&self) -> Option<Box<HandshakeData>> {
98 self.conn
99 .crypto_session()
100 .handshake_data()
101 .map(|data| data.downcast::<HandshakeData>().unwrap())
102 }
103
104 pub(crate) fn check_0rtt(&self) -> bool {
105 self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt()
106 }
107}
108
109fn wake_stream(stream: StreamId, wakers: &mut HashMap<StreamId, Waker>) {
110 if let Some(waker) = wakers.remove(&stream) {
111 waker.wake();
112 }
113}
114
115fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
116 wakers.drain().for_each(|(_, waker)| waker.wake())
117}
118
/// Shared core of a connection, owned jointly by the public handles and the worker task.
#[derive(Debug)]
pub(crate) struct ConnectionInner {
    /// Mutable connection state; lock via `state()` / `try_state()`.
    state: Mutex<ConnectionState>,
    /// Identifier sent alongside every endpoint event this connection emits.
    handle: ConnectionHandle,
    /// Socket used by the worker to send outgoing transmits.
    socket: Socket,
    /// Outgoing channel for endpoint events produced by quinn-proto.
    events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
    /// Incoming channel of events destined for this connection's worker.
    events_rx: Receiver<ConnectionEvent>,
}
127
128fn implicit_close(this: &Shared<ConnectionInner>) {
129 if Shared::strong_count(this) == 2 {
130 this.state().close(0u32.into(), Bytes::new())
131 }
132}
133
impl ConnectionInner {
    /// Builds the shared core around a freshly created quinn-proto connection.
    ///
    /// The connection starts neither connected nor failed, with no worker attached and
    /// with every waiter slot/queue empty.
    fn new(
        handle: ConnectionHandle,
        conn: quinn_proto::Connection,
        socket: Socket,
        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
        events_rx: Receiver<ConnectionEvent>,
    ) -> Self {
        Self {
            state: Mutex::new(ConnectionState {
                conn,
                connected: false,
                error: None,
                worker: None,
                poller: None,
                on_connected: None,
                on_handshake_data: None,
                datagram_received: VecDeque::new(),
                datagrams_unblocked: VecDeque::new(),
                stream_opened: [VecDeque::new(), VecDeque::new()],
                stream_available: [VecDeque::new(), VecDeque::new()],
                writable: HashMap::default(),
                readable: HashMap::default(),
                stopped: HashMap::default(),
            }),
            handle,
            socket,
            events_tx,
            events_rx,
        }
    }

    /// Locks and returns the connection state, whether or not the connection has terminated.
    #[inline]
    pub(crate) fn state(&self) -> MutexGuard<'_, ConnectionState> {
        self.state.lock()
    }

    /// Locks the connection state.
    ///
    /// # Errors
    ///
    /// Returns a clone of the terminal error if the connection has already been terminated.
    #[inline]
    pub(crate) fn try_state(&self) -> Result<MutexGuard<'_, ConnectionState>, ConnectionError> {
        let state = self.state();
        if let Some(error) = &state.error {
            Err(error.clone())
        } else {
            Ok(state)
        }
    }

    /// Worker loop that drives the connection until quinn-proto reports it as drained.
    ///
    /// Each iteration waits for one of four triggers:
    /// - an explicit wake-up through `ConnectionState::wake`;
    /// - the protocol timer;
    /// - a batch of incoming `ConnectionEvent`s;
    /// - completion of the in-flight send.
    ///
    /// With the state lock held, it then:
    /// - starts at most one new transmission;
    /// - re-arms the timer;
    /// - forwards endpoint events;
    /// - dispatches application events to the registered wakers.
    async fn run(&self) {
        // Yields an item only when `wake()` has taken the stored waker (i.e. someone asked the
        // worker to iterate); otherwise (re)registers this task's waker and stays pending.
        // The very first poll is ready because `poller` starts as `None`.
        let mut poller = stream::poll_fn(|cx| {
            let mut state = self.state();
            let ready = state.poller.is_none();
            match &state.poller {
                Some(waker) if waker.will_wake(cx.waker()) => {}
                _ => state.poller = Some(cx.waker().clone()),
            };
            if ready {
                Poll::Ready(Some(()))
            } else {
                Poll::Pending
            }
        })
        .fuse();

        let mut timer = Timer::new();
        // Incoming events are processed in batches of at most 100 per iteration.
        let mut event_stream = self.events_rx.stream().ready_chunks(100);
        // Reusable transmit buffer: `None` while a send is in flight and owns the buffer.
        let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
        let mut transmit_fut = pin!(Fuse::terminated());

        loop {
            let mut state = select! {
                _ = poller.select_next_some() => self.state(),
                _ = timer => {
                    // The sleep completed; disarm before asking quinn-proto for the next deadline.
                    timer.reset(None);
                    let mut state = self.state();
                    state.conn.handle_timeout(Instant::now());
                    state
                }
                events = event_stream.select_next_some() => {
                    let mut state = self.state();
                    for event in events {
                        match event {
                            ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason),
                            ConnectionEvent::Proto(event) => state.conn.handle_event(event),
                        }
                    }
                    state
                },
                buf = transmit_fut => {
                    // The send finished and handed its buffer back; recycle it.
                    let mut buf: Vec<_> = buf;
                    buf.clear();
                    send_buf = Some(buf);
                    self.state()
                },
            };

            // Only one transmission is in flight at a time: a new one is started only when the
            // buffer is available, otherwise it is put back unused.
            if let Some(mut buf) = send_buf.take() {
                if let Some(transmit) = state.conn.poll_transmit(
                    Instant::now(),
                    self.socket.max_gso_segments(),
                    &mut buf,
                ) {
                    transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse())
                } else {
                    send_buf = Some(buf);
                }
            }

            timer.reset(state.conn.poll_timeout());

            // Send errors are ignored (e.g. if the receiving side has gone away).
            while let Some(event) = state.conn.poll_endpoint_events() {
                let _ = self.events_tx.send((self.handle, event));
            }

            while let Some(event) = state.conn.poll() {
                use quinn_proto::Event::*;
                match event {
                    HandshakeDataReady => {
                        if let Some(waker) = state.on_handshake_data.take() {
                            waker.wake()
                        }
                    }
                    Connected => {
                        state.connected = true;
                        if let Some(waker) = state.on_connected.take() {
                            waker.wake()
                        }
                        // A client whose 0-RTT data was not accepted wakes every stream waiter
                        // so that they re-check their state.
                        if state.conn.side().is_client() && !state.conn.accepted_0rtt() {
                            wake_all_streams(&mut state.writable);
                            wake_all_streams(&mut state.readable);
                            wake_all_streams(&mut state.stopped);
                        }
                    }
                    ConnectionLost { reason } => state.terminate(reason.into()),
                    Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable),
                    Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable),
                    Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped),
                    Stream(StreamEvent::Stopped { id, .. }) => {
                        wake_stream(id, &mut state.stopped);
                        wake_stream(id, &mut state.writable);
                    }
                    Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    DatagramReceived => state.datagram_received.drain(..).for_each(Waker::wake),
                    DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake),
                }
            }

            // Once drained the worker exits, releasing its clone of the shared state.
            if state.conn.is_drained() {
                break;
            }
        }
    }
}
294
/// Generates the read-only accessors shared by `Connecting` and `Connection`.
///
/// Both types wrap a `Shared<ConnectionInner>` in field `.0`; every accessor briefly locks
/// the connection state and queries the quinn-proto connection.
macro_rules! conn_fn {
    () => {
        /// Local IP address this connection is using, if quinn-proto knows it.
        pub fn local_ip(&self) -> Option<IpAddr> {
            let state = self.0.state();
            state.conn.local_ip()
        }

        /// Current address of the peer.
        pub fn remote_address(&self) -> SocketAddr {
            let state = self.0.state();
            state.conn.remote_address()
        }

        /// Current round-trip time estimate.
        pub fn rtt(&self) -> Duration {
            let state = self.0.state();
            state.conn.rtt()
        }

        /// Snapshot of the connection statistics.
        pub fn stats(&self) -> ConnectionStats {
            let state = self.0.state();
            state.conn.stats()
        }

        /// Boxed clone of the current congestion controller state.
        pub fn congestion_state(&self) -> Box<dyn Controller> {
            let state = self.0.state();
            state.conn.congestion_state().clone_box()
        }

        /// Certificate chain presented by the peer, if any.
        ///
        /// # Panics
        ///
        /// Panics if the crypto session's peer identity is not a rustls certificate chain.
        pub fn peer_identity(
            &self,
        ) -> Option<Box<Vec<rustls::pki_types::CertificateDer<'static>>>> {
            let identity = self.0.state().conn.crypto_session().peer_identity()?;
            Some(identity.downcast().unwrap())
        }

        /// Derives keying material from the crypto session into `output`.
        ///
        /// # Errors
        ///
        /// Propagates the crypto session's `ExportKeyingMaterialError`.
        pub fn export_keying_material(
            &self,
            output: &mut [u8],
            label: &[u8],
            context: &[u8],
        ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> {
            let state = self.0.state();
            state
                .conn
                .crypto_session()
                .export_keying_material(output, label, context)
        }
    };
}
369
/// An in-progress connection attempt.
///
/// Awaiting it resolves to a [`Connection`] once the `Connected` event has been observed,
/// or to the connection's terminal error.
#[derive(Debug)]
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
pub struct Connecting(Shared<ConnectionInner>);
374
375impl Connecting {
376 conn_fn!();
377
378 pub(crate) fn new(
379 handle: ConnectionHandle,
380 conn: quinn_proto::Connection,
381 socket: Socket,
382 events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
383 events_rx: Receiver<ConnectionEvent>,
384 ) -> Self {
385 let inner = Shared::new(ConnectionInner::new(
386 handle, conn, socket, events_tx, events_rx,
387 ));
388 let worker = compio_runtime::spawn({
389 let inner = inner.clone();
390 async move { inner.run().await }.in_current_span()
391 });
392 inner.state().worker = Some(worker);
393 Self(inner)
394 }
395
396 #[cfg(rustls)]
398 pub async fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
399 future::poll_fn(|cx| {
400 let mut state = self.0.try_state()?;
401 if let Some(data) = state.handshake_data() {
402 return Poll::Ready(Ok(data));
403 }
404
405 match &state.on_handshake_data {
406 Some(waker) if waker.will_wake(cx.waker()) => {}
407 _ => state.on_handshake_data = Some(cx.waker().clone()),
408 }
409
410 Poll::Pending
411 })
412 .await
413 }
414
415 pub fn into_0rtt(self) -> Result<Connection, Self> {
463 let is_ok = {
464 let state = self.0.state();
465 state.conn.has_0rtt() || state.conn.side().is_server()
466 };
467 if is_ok {
468 Ok(Connection(self.0.clone()))
469 } else {
470 Err(self)
471 }
472 }
473}
474
475impl Future for Connecting {
476 type Output = Result<Connection, ConnectionError>;
477
478 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
479 let mut state = self.0.try_state()?;
480
481 if state.connected {
482 return Poll::Ready(Ok(Connection(self.0.clone())));
483 }
484
485 match &state.on_connected {
486 Some(waker) if waker.will_wake(cx.waker()) => {}
487 _ => state.on_connected = Some(cx.waker().clone()),
488 }
489
490 Poll::Pending
491 }
492}
493
494impl Drop for Connecting {
495 fn drop(&mut self) {
496 implicit_close(&self.0)
497 }
498}
499
/// Handle to a QUIC connection.
///
/// Clones share the same underlying state. Dropping a handle runs `implicit_close`, which may
/// close the connection when no other user-facing owner remains.
#[derive(Debug, Clone)]
pub struct Connection(Shared<ConnectionInner>);
503
504impl Connection {
505 conn_fn!();
506
507 #[cfg(rustls)]
509 pub fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
510 Ok(self.0.try_state()?.handshake_data().unwrap())
511 }
512
513 pub fn max_datagram_size(&self) -> Option<usize> {
526 self.0.state().conn.datagrams().max_size()
527 }
528
529 pub fn datagram_send_buffer_space(&self) -> usize {
535 self.0.state().conn.datagrams().send_buffer_space()
536 }
537
538 pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
545 let mut state = self.0.state();
546 state.conn.set_max_concurrent_streams(Dir::Uni, count);
547 state.wake();
549 }
550
551 pub fn set_receive_window(&self, receive_window: VarInt) {
553 let mut state = self.0.state();
554 state.conn.set_receive_window(receive_window);
555 state.wake();
556 }
557
558 pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
565 let mut state = self.0.state();
566 state.conn.set_max_concurrent_streams(Dir::Bi, count);
567 state.wake();
569 }
570
571 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
608 self.0
609 .state()
610 .close(error_code, Bytes::copy_from_slice(reason));
611 }
612
613 pub async fn closed(&self) -> ConnectionError {
615 let worker = self.0.state().worker.take();
616 if let Some(worker) = worker {
617 let _ = worker.await;
618 }
619
620 self.0.try_state().unwrap_err()
621 }
622
623 pub fn close_reason(&self) -> Option<ConnectionError> {
627 self.0.try_state().err()
628 }
629
630 fn poll_recv_datagram(&self, cx: &mut Context) -> Poll<Result<Bytes, ConnectionError>> {
631 let mut state = self.0.try_state()?;
632 if let Some(bytes) = state.conn.datagrams().recv() {
633 return Poll::Ready(Ok(bytes));
634 }
635 state.datagram_received.push_back(cx.waker().clone());
636 Poll::Pending
637 }
638
639 pub async fn recv_datagram(&self) -> Result<Bytes, ConnectionError> {
641 future::poll_fn(|cx| self.poll_recv_datagram(cx)).await
642 }
643
644 fn try_send_datagram(
645 &self,
646 cx: Option<&mut Context>,
647 data: Bytes,
648 ) -> Result<(), Result<SendDatagramError, Bytes>> {
649 use quinn_proto::SendDatagramError::*;
650 let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?;
651 state
652 .conn
653 .datagrams()
654 .send(data, cx.is_none())
655 .map_err(|err| match err {
656 UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer),
657 Disabled => Ok(SendDatagramError::Disabled),
658 TooLarge => Ok(SendDatagramError::TooLarge),
659 Blocked(data) => {
660 state
661 .datagrams_unblocked
662 .push_back(cx.unwrap().waker().clone());
663 Err(data)
664 }
665 })?;
666 state.wake();
667 Ok(())
668 }
669
670 pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
676 self.try_send_datagram(None, data).map_err(Result::unwrap)
677 }
678
679 pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> {
689 let mut data = Some(data);
690 future::poll_fn(
691 |cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) {
692 Ok(()) => Poll::Ready(Ok(())),
693 Err(Ok(e)) => Poll::Ready(Err(e)),
694 Err(Err(b)) => {
695 data.replace(b);
696 Poll::Pending
697 }
698 },
699 )
700 .await
701 }
702
703 fn poll_open_stream(
704 &self,
705 cx: Option<&mut Context>,
706 dir: Dir,
707 ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
708 let mut state = self.0.try_state()?;
709 if let Some(stream) = state.conn.streams().open(dir) {
710 Poll::Ready(Ok((
711 stream,
712 state.conn.side().is_client() && state.conn.is_handshaking(),
713 )))
714 } else {
715 if let Some(cx) = cx {
716 state.stream_available[dir as usize].push_back(cx.waker().clone());
717 }
718 Poll::Pending
719 }
720 }
721
722 pub fn open_uni(&self) -> Result<SendStream, OpenStreamError> {
728 if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Uni)? {
729 Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
730 } else {
731 Err(OpenStreamError::StreamsExhausted)
732 }
733 }
734
735 pub async fn open_uni_wait(&self) -> Result<SendStream, ConnectionError> {
744 let (stream, is_0rtt) =
745 future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?;
746 Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
747 }
748
749 pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> {
755 if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Bi)? {
756 Ok((
757 SendStream::new(self.0.clone(), stream, is_0rtt),
758 RecvStream::new(self.0.clone(), stream, is_0rtt),
759 ))
760 } else {
761 Err(OpenStreamError::StreamsExhausted)
762 }
763 }
764
765 pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
774 let (stream, is_0rtt) =
775 future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?;
776 Ok((
777 SendStream::new(self.0.clone(), stream, is_0rtt),
778 RecvStream::new(self.0.clone(), stream, is_0rtt),
779 ))
780 }
781
782 fn poll_accept_stream(
783 &self,
784 cx: &mut Context,
785 dir: Dir,
786 ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
787 let mut state = self.0.try_state()?;
788 if let Some(stream) = state.conn.streams().accept(dir) {
789 state.wake();
790 Poll::Ready(Ok((stream, state.conn.is_handshaking())))
791 } else {
792 state.stream_opened[dir as usize].push_back(cx.waker().clone());
793 Poll::Pending
794 }
795 }
796
797 pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
799 let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?;
800 Ok(RecvStream::new(self.0.clone(), stream, is_0rtt))
801 }
802
803 pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
815 let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?;
816 Ok((
817 SendStream::new(self.0.clone(), stream, is_0rtt),
818 RecvStream::new(self.0.clone(), stream, is_0rtt),
819 ))
820 }
821
822 pub async fn accepted_0rtt(&self) -> Result<bool, ConnectionError> {
827 future::poll_fn(|cx| {
828 let mut state = self.0.try_state()?;
829
830 if state.connected {
831 return Poll::Ready(Ok(state.conn.accepted_0rtt()));
832 }
833
834 match &state.on_connected {
835 Some(waker) if waker.will_wake(cx.waker()) => {}
836 _ => state.on_connected = Some(cx.waker().clone()),
837 }
838
839 Poll::Pending
840 })
841 .await
842 }
843}
844
845impl PartialEq for Connection {
846 fn eq(&self, other: &Self) -> bool {
847 Shared::ptr_eq(&self.0, &other.0)
848 }
849}
850
851impl Eq for Connection {}
852
853impl Drop for Connection {
854 fn drop(&mut self) {
855 implicit_close(&self.0)
856 }
857}
858
/// Re-armable one-shot timer used by the connection worker for quinn-proto timeouts.
struct Timer {
    /// Deadline the timer is currently armed for, or `None` when disarmed.
    deadline: Option<Instant>,
    /// Sleep future for `deadline`, or a terminated future when disarmed.
    fut: Fuse<LocalBoxFuture<'static, ()>>,
}
863
864impl Timer {
865 fn new() -> Self {
866 Self {
867 deadline: None,
868 fut: Fuse::terminated(),
869 }
870 }
871
872 fn reset(&mut self, deadline: Option<Instant>) {
873 if let Some(deadline) = deadline {
874 if self.deadline.is_none() || self.deadline != Some(deadline) {
875 self.fut = compio_runtime::time::sleep_until(deadline)
876 .boxed_local()
877 .fuse();
878 }
879 } else {
880 self.fut = Fuse::terminated();
881 }
882 self.deadline = deadline;
883 }
884}
885
886impl Future for Timer {
887 type Output = ();
888
889 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
890 self.fut.poll_unpin(cx)
891 }
892}
893
894impl FusedFuture for Timer {
895 fn is_terminated(&self) -> bool {
896 self.fut.is_terminated()
897 }
898}
899
/// Reasons a connection can end; mirrors `quinn_proto::ConnectionError` variant for variant.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
    /// The peer doesn't implement any supported version.
    #[error("peer doesn't implement any supported version")]
    VersionMismatch,
    /// A transport-level error (`quinn_proto::TransportError`).
    #[error(transparent)]
    TransportError(#[from] quinn_proto::TransportError),
    /// The connection was aborted by the peer; carries the received close frame.
    #[error("aborted by peer: {0}")]
    ConnectionClosed(quinn_proto::ConnectionClose),
    /// The peer's application closed the connection; carries its close frame.
    #[error("closed by peer: {0}")]
    ApplicationClosed(quinn_proto::ApplicationClose),
    /// The connection was reset by the peer.
    #[error("reset by peer")]
    Reset,
    /// The connection timed out.
    #[error("timed out")]
    TimedOut,
    /// The connection was closed locally (recorded by `ConnectionState::close`).
    #[error("closed")]
    LocallyClosed,
    /// Connection IDs were exhausted.
    #[error("CIDs exhausted")]
    CidsExhausted,
}
939
940impl From<quinn_proto::ConnectionError> for ConnectionError {
941 fn from(value: quinn_proto::ConnectionError) -> Self {
942 use quinn_proto::ConnectionError::*;
943
944 match value {
945 VersionMismatch => ConnectionError::VersionMismatch,
946 TransportError(e) => ConnectionError::TransportError(e),
947 ConnectionClosed(e) => ConnectionError::ConnectionClosed(e),
948 ApplicationClosed(e) => ConnectionError::ApplicationClosed(e),
949 Reset => ConnectionError::Reset,
950 TimedOut => ConnectionError::TimedOut,
951 LocallyClosed => ConnectionError::LocallyClosed,
952 CidsExhausted => ConnectionError::CidsExhausted,
953 }
954 }
955}
956
/// Errors returned when sending an unreliable datagram.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum SendDatagramError {
    /// The peer does not support datagrams.
    #[error("datagrams not supported by peer")]
    UnsupportedByPeer,
    /// Datagram support is disabled locally.
    #[error("datagram support disabled")]
    Disabled,
    /// The datagram is larger than can currently be sent.
    #[error("datagram too large")]
    TooLarge,
    /// The connection has ended; carries its terminal error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
}
976
/// Errors returned by the non-waiting stream openers (`open_uni` / `open_bi`).
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum OpenStreamError {
    /// The connection has ended; carries its terminal error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// No stream of the requested direction can be opened right now.
    #[error("streams exhausted")]
    StreamsExhausted,
}
987
/// Glue implementing the `h3` / `h3-datagram` QUIC traits on top of this crate's types.
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use std::sync::Arc;

    use compio_buf::bytes::Buf;
    use futures_util::ready;
    use h3::{
        error::Code,
        quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, WriteBuf},
    };
    use h3_datagram::{
        datagram::EncodedDatagram,
        quic_traits::{
            DatagramConnectionExt, RecvDatagram, SendDatagram, SendDatagramErrorIncoming,
        },
    };

    use super::*;
    use crate::send_stream::h3_impl::SendStream;

    /// Application closes keep their error code, timeouts map to `Timeout`, and every other
    /// error is wrapped as `Undefined`.
    impl From<ConnectionError> for ConnectionErrorIncoming {
        fn from(e: ConnectionError) -> Self {
            use ConnectionError::*;
            match e {
                ApplicationClosed(e) => Self::ApplicationClose {
                    error_code: e.error_code.into_inner(),
                },
                TimedOut => Self::Timeout,

                e => Self::Undefined(Arc::new(e)),
            }
        }
    }

    /// A connection error surfaced on a stream is reported as a connection-level error.
    impl From<ConnectionError> for StreamErrorIncoming {
        fn from(e: ConnectionError) -> Self {
            Self::ConnectionErrorIncoming {
                connection_error: e.into(),
            }
        }
    }

    /// Unsupported and disabled datagrams both map to `NotAvailable`.
    impl From<SendDatagramError> for SendDatagramErrorIncoming {
        fn from(e: SendDatagramError) -> Self {
            use SendDatagramError::*;
            match e {
                UnsupportedByPeer | Disabled => Self::NotAvailable,
                TooLarge => Self::TooLarge,
                ConnectionLost(e) => Self::ConnectionError(e.into()),
            }
        }
    }

    impl<B> SendDatagram<B> for Connection
    where
        B: Buf,
    {
        /// Flattens the encoded datagram into contiguous `Bytes` and sends it without waiting
        /// (`Connection::send_datagram`).
        fn send_datagram<T: Into<EncodedDatagram<B>>>(
            &mut self,
            data: T,
        ) -> Result<(), SendDatagramErrorIncoming> {
            let mut buf: EncodedDatagram<B> = data.into();
            let buf = buf.copy_to_bytes(buf.remaining());
            Ok(Connection::send_datagram(self, buf)?)
        }
    }

    impl RecvDatagram for Connection {
        type Buffer = Bytes;

        /// Delegates to `Connection::poll_recv_datagram`, converting the error type.
        fn poll_incoming_datagram(
            &mut self,
            cx: &mut core::task::Context<'_>,
        ) -> Poll<Result<Self::Buffer, ConnectionErrorIncoming>> {
            Poll::Ready(Ok(ready!(self.poll_recv_datagram(cx))?))
        }
    }

    /// Both datagram handlers are plain clones of the connection handle.
    impl<B: Buf> DatagramConnectionExt<B> for Connection {
        type RecvDatagramHandler = Self;
        type SendDatagramHandler = Self;

        fn send_datagram_handler(&self) -> Self::SendDatagramHandler {
            self.clone()
        }

        fn recv_datagram_handler(&self) -> Self::RecvDatagramHandler {
            self.clone()
        }
    }

    /// Send and receive halves of one bidirectional stream, as expected by h3.
    pub struct BidiStream<B> {
        send: SendStream<B>,
        recv: RecvStream,
    }

    impl<B> BidiStream<B> {
        /// Creates both halves for `stream` on the same shared connection.
        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                send: SendStream::new(conn.clone(), stream, is_0rtt),
                recv: RecvStream::new(conn, stream, is_0rtt),
            }
        }
    }

    impl<B> quic::BidiStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        type RecvStream = RecvStream;
        type SendStream = SendStream<B>;

        fn split(self) -> (Self::SendStream, Self::RecvStream) {
            (self.send, self.recv)
        }
    }

    /// Receive-side trait methods forward to the `recv` half.
    impl<B> quic::RecvStream for BidiStream<B>
    where
        B: Buf,
    {
        type Buf = Bytes;

        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            self.recv.poll_data(cx)
        }

        fn stop_sending(&mut self, error_code: u64) {
            self.recv.stop_sending(error_code)
        }

        fn recv_id(&self) -> quic::StreamId {
            self.recv.recv_id()
        }
    }

    /// Send-side trait methods forward to the `send` half.
    impl<B> quic::SendStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_ready(cx)
        }

        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            self.send.send_data(data)
        }

        fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_finish(cx)
        }

        fn reset(&mut self, reset_code: u64) {
            self.send.reset(reset_code)
        }

        fn send_id(&self) -> quic::StreamId {
            self.send.send_id()
        }
    }

    impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
    where
        B: Buf,
    {
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            self.send.poll_send(cx, buf)
        }
    }

    /// Cloneable stream opener wrapping a connection handle; returned by `opener()`.
    #[derive(Clone)]
    pub struct OpenStreams(Connection);

    impl<B> quic::OpenStreams<B> for OpenStreams
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.0.clone(), stream, is_0rtt)))
        }

        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Uni))?;
            Poll::Ready(Ok(SendStream::new(self.0.0.clone(), stream, is_0rtt)))
        }

        /// Closes the connection. Panics if the h3 code does not fit in a QUIC `VarInt`.
        fn close(&mut self, code: Code, reason: &[u8]) {
            self.0
                .close(code.value().try_into().expect("invalid code"), reason)
        }
    }

    impl<B> quic::OpenStreams<B> for Connection
    where
        B: Buf,
    {
        type BidiStream = BidiStream<B>;
        type SendStream = SendStream<B>;

        fn poll_open_bidi(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
        }

        fn poll_open_send(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?;
            Poll::Ready(Ok(SendStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Closes the connection. Panics if the h3 code does not fit in a QUIC `VarInt`.
        fn close(&mut self, code: Code, reason: &[u8]) {
            Connection::close(self, code.value().try_into().expect("invalid code"), reason)
        }
    }

    impl<B> quic::Connection<B> for Connection
    where
        B: Buf,
    {
        type OpenStreams = OpenStreams;
        type RecvStream = RecvStream;

        fn poll_accept_recv(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Uni))?;
            Poll::Ready(Ok(RecvStream::new(self.0.clone(), stream, is_0rtt)))
        }

        fn poll_accept_bidi(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
            let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Bi))?;
            Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
        }

        /// Returns a cloneable opener sharing this connection.
        fn opener(&self) -> Self::OpenStreams {
            OpenStreams(self.clone())
        }
    }
}