Skip to main content

compio_tls/
stream.rs

1use std::{
2    borrow::Cow,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf};
9use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
10
/// Backend-specific TLS stream stored inside [`TlsStream`].
///
/// One variant exists for each enabled TLS backend feature. Every backend wraps
/// the same `Pin<Box<AsyncStream<S>>>` adapter around the raw stream `S`.
#[derive(Debug)]
// The backends' stream types can differ considerably in size; they are stored
// inline rather than boxed (presumably to avoid an extra indirection on every
// I/O call — TODO confirm this is still the intent).
#[allow(clippy::large_enum_variant)]
enum TlsStreamInner<S: Splittable> {
    #[cfg(feature = "native-tls")]
    NativeTls(crate::native::TlsStream<Pin<Box<AsyncStream<S>>>>),
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsStream<Pin<Box<AsyncStream<S>>>>),
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(crate::py_ossl::TlsStream<Pin<Box<AsyncStream<S>>>>),
    // Placeholder compiled only when no TLS backend feature is enabled.
    // `Infallible` makes this variant (and therefore the whole enum)
    // uninhabited, while `PhantomData` keeps the type parameter `S` in use.
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl",
    )))]
    None(
        std::convert::Infallible,
        std::marker::PhantomData<Pin<Box<AsyncStream<S>>>>,
    ),
}
30
31impl<S: Splittable + 'static> TlsStreamInner<S>
32where
33    S::ReadHalf: AsyncRead + Unpin,
34    S::WriteHalf: AsyncWrite + Unpin,
35{
36    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
37        match self {
38            #[cfg(feature = "native-tls")]
39            Self::NativeTls(s) => s.negotiated_alpn().ok().flatten().map(Cow::from),
40            #[cfg(feature = "rustls")]
41            Self::Rustls(s) => s.get_ref().1.alpn_protocol().map(Cow::from),
42            #[cfg(feature = "py-dynamic-openssl")]
43            Self::PyDynamicOpenSsl(s) => s.negotiated_alpn().map(Cow::from),
44            #[cfg(not(any(
45                feature = "native-tls",
46                feature = "rustls",
47                feature = "py-dynamic-openssl",
48            )))]
49            Self::None(f, ..) => match *f {},
50        }
51    }
52}
53
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// A `TlsStream<S>` represents a handshake that has completed successfully;
/// both the client and the server are ready to send and receive data. Bytes
/// read from a `TlsStream` are decrypted from `S`, and bytes written to a
/// `TlsStream` are encrypted on their way through to `S`.
///
/// The TLS implementation in use is chosen by the enabled crate features
/// (`native-tls`, `rustls`, `py-dynamic-openssl`). The stream implements both
/// compio's `AsyncRead`/`AsyncWrite` and the `futures_util` `AsyncRead`/
/// `AsyncWrite` traits.
#[derive(Debug)]
pub struct TlsStream<S: Splittable>(TlsStreamInner<S>);
63
64impl<S: Splittable + 'static> TlsStream<S>
65where
66    S::ReadHalf: AsyncRead + Unpin,
67    S::WriteHalf: AsyncWrite + Unpin,
68{
69    /// Returns the negotiated ALPN protocol.
70    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
71        self.0.negotiated_alpn()
72    }
73}
74
#[cfg(feature = "native-tls")]
#[doc(hidden)]
impl<S: Splittable> From<crate::native::TlsStream<Pin<Box<AsyncStream<S>>>>> for TlsStream<S> {
    /// Wraps a native-tls backed stream.
    fn from(stream: crate::native::TlsStream<Pin<Box<AsyncStream<S>>>>) -> Self {
        let inner = TlsStreamInner::NativeTls(stream);
        Self(inner)
    }
}
82
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S: Splittable> From<futures_rustls::client::TlsStream<Pin<Box<AsyncStream<S>>>>>
    for TlsStream<S>
{
    /// Wraps the client side of a rustls connection.
    fn from(stream: futures_rustls::client::TlsStream<Pin<Box<AsyncStream<S>>>>) -> Self {
        let rustls_stream = futures_rustls::TlsStream::Client(stream);
        Self(TlsStreamInner::Rustls(rustls_stream))
    }
}
94
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S: Splittable> From<futures_rustls::server::TlsStream<Pin<Box<AsyncStream<S>>>>>
    for TlsStream<S>
{
    /// Wraps the server side of a rustls connection.
    fn from(stream: futures_rustls::server::TlsStream<Pin<Box<AsyncStream<S>>>>) -> Self {
        let rustls_stream = futures_rustls::TlsStream::Server(stream);
        Self(TlsStreamInner::Rustls(rustls_stream))
    }
}
106
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl<S: Splittable> From<crate::py_ossl::TlsStream<Pin<Box<AsyncStream<S>>>>> for TlsStream<S> {
    /// Wraps a stream backed by the dynamically loaded Python OpenSSL.
    fn from(stream: crate::py_ossl::TlsStream<Pin<Box<AsyncStream<S>>>>) -> Self {
        let inner = TlsStreamInner::PyDynamicOpenSsl(stream);
        Self(inner)
    }
}
114
115impl<S: Splittable + 'static> futures_util::AsyncRead for TlsStream<S>
116where
117    S::ReadHalf: AsyncRead + Unpin,
118    S::WriteHalf: AsyncWrite + Unpin,
119{
120    fn poll_read(
121        self: Pin<&mut Self>,
122        cx: &mut Context<'_>,
123        buf: &mut [u8],
124    ) -> Poll<io::Result<usize>> {
125        match &mut self.get_mut().0 {
126            #[cfg(feature = "native-tls")]
127            TlsStreamInner::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
128            #[cfg(feature = "rustls")]
129            TlsStreamInner::Rustls(s) => Pin::new(s).poll_read(cx, buf),
130            #[cfg(feature = "py-dynamic-openssl")]
131            TlsStreamInner::PyDynamicOpenSsl(s) => Pin::new(s).poll_read(cx, buf),
132            #[cfg(not(any(
133                feature = "native-tls",
134                feature = "rustls",
135                feature = "py-dynamic-openssl",
136            )))]
137            TlsStreamInner::None(f, ..) => match *f {},
138        }
139    }
140}
141
142pub(crate) async fn read_futures<S: futures_util::AsyncRead + Unpin, B: IoBufMut>(
143    s: &mut S,
144    mut buf: B,
145) -> BufResult<usize, B> {
146    let slice = buf.ensure_init();
147    let res = futures_util::AsyncReadExt::read(s, slice).await;
148    let res = match res {
149        Ok(len) => {
150            unsafe { buf.advance_to(len) };
151            Ok(len)
152        }
153        // TLS streams may return UnexpectedEof when the connection is closed.
154        // https://docs.rs/rustls/latest/rustls/manual/_03_howto/index.html#unexpected-eof
155        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(0),
156        _ => res,
157    };
158    BufResult(res, buf)
159}
160
161impl<S: Splittable + 'static> AsyncRead for TlsStream<S>
162where
163    S::ReadHalf: AsyncRead + Unpin,
164    S::WriteHalf: AsyncWrite + Unpin,
165{
166    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
167        read_futures(self, buf).await
168    }
169}
170
171impl<S: Splittable + 'static> futures_util::AsyncWrite for TlsStream<S>
172where
173    S::ReadHalf: AsyncRead + Unpin,
174    S::WriteHalf: AsyncWrite + Unpin,
175{
176    fn poll_write(
177        self: Pin<&mut Self>,
178        cx: &mut Context<'_>,
179        buf: &[u8],
180    ) -> Poll<io::Result<usize>> {
181        match &mut self.get_mut().0 {
182            #[cfg(feature = "native-tls")]
183            TlsStreamInner::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
184            #[cfg(feature = "rustls")]
185            TlsStreamInner::Rustls(s) => Pin::new(s).poll_write(cx, buf),
186            #[cfg(feature = "py-dynamic-openssl")]
187            TlsStreamInner::PyDynamicOpenSsl(s) => Pin::new(s).poll_write(cx, buf),
188            #[cfg(not(any(
189                feature = "native-tls",
190                feature = "rustls",
191                feature = "py-dynamic-openssl",
192            )))]
193            TlsStreamInner::None(f, ..) => match *f {},
194        }
195    }
196
197    fn poll_write_vectored(
198        self: Pin<&mut Self>,
199        cx: &mut Context<'_>,
200        bufs: &[io::IoSlice<'_>],
201    ) -> Poll<io::Result<usize>> {
202        match &mut self.get_mut().0 {
203            #[cfg(feature = "native-tls")]
204            TlsStreamInner::NativeTls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
205            #[cfg(feature = "rustls")]
206            TlsStreamInner::Rustls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
207            #[cfg(feature = "py-dynamic-openssl")]
208            TlsStreamInner::PyDynamicOpenSsl(s) => Pin::new(s).poll_write_vectored(cx, bufs),
209            #[cfg(not(any(
210                feature = "native-tls",
211                feature = "rustls",
212                feature = "py-dynamic-openssl",
213            )))]
214            TlsStreamInner::None(f, ..) => match *f {},
215        }
216    }
217
218    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
219        match &mut self.get_mut().0 {
220            #[cfg(feature = "native-tls")]
221            TlsStreamInner::NativeTls(s) => Pin::new(s).poll_flush(cx),
222            #[cfg(feature = "rustls")]
223            TlsStreamInner::Rustls(s) => Pin::new(s).poll_flush(cx),
224            #[cfg(feature = "py-dynamic-openssl")]
225            TlsStreamInner::PyDynamicOpenSsl(s) => Pin::new(s).poll_flush(cx),
226            #[cfg(not(any(
227                feature = "native-tls",
228                feature = "rustls",
229                feature = "py-dynamic-openssl",
230            )))]
231            TlsStreamInner::None(f, ..) => match *f {},
232        }
233    }
234
235    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
236        match &mut self.get_mut().0 {
237            #[cfg(feature = "native-tls")]
238            TlsStreamInner::NativeTls(s) => Pin::new(s).poll_close(cx),
239            #[cfg(feature = "rustls")]
240            TlsStreamInner::Rustls(s) => Pin::new(s).poll_close(cx),
241            #[cfg(feature = "py-dynamic-openssl")]
242            TlsStreamInner::PyDynamicOpenSsl(s) => Pin::new(s).poll_close(cx),
243            #[cfg(not(any(
244                feature = "native-tls",
245                feature = "rustls",
246                feature = "py-dynamic-openssl",
247            )))]
248            TlsStreamInner::None(f, ..) => match *f {},
249        }
250    }
251}
252
253impl<S: Splittable + 'static> AsyncWrite for TlsStream<S>
254where
255    S::ReadHalf: AsyncRead + Unpin,
256    S::WriteHalf: AsyncWrite + Unpin,
257{
258    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
259        let slice = buf.as_init();
260        let res = futures_util::AsyncWriteExt::write(self, slice).await;
261        BufResult(res, buf)
262    }
263
264    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
265        let slices = buf.iter_slice().map(io::IoSlice::new).collect::<Vec<_>>();
266        let res = futures_util::AsyncWriteExt::write_vectored(self, &slices).await;
267        BufResult(res, buf)
268    }
269
270    async fn flush(&mut self) -> io::Result<()> {
271        futures_util::AsyncWriteExt::flush(self).await
272    }
273
274    async fn shutdown(&mut self) -> io::Result<()> {
275        futures_util::AsyncWriteExt::close(self).await
276    }
277}