Skip to main content

compio_tls/
stream.rs

1#[cfg(feature = "rustls")]
2use std::pin::Pin;
3use std::{borrow::Cow, io, mem::MaybeUninit};
4
5use compio_buf::{BufResult, IoBuf, IoBufMut};
6use compio_io::{
7    AsyncRead, AsyncWrite,
8    compat::{AsyncStream, SyncStream},
9};
10
/// The backend-specific TLS stream wrapped by [`TlsStream`].
///
/// One variant exists per enabled TLS backend feature. The `native-tls` and
/// `py-dynamic-openssl` backends operate on a [`SyncStream`] adapter around the
/// raw stream, while the `rustls` backend operates on a pinned, boxed
/// [`AsyncStream`] adapter.
#[derive(Debug)]
// Clippy's `large_enum_variant` lint is silenced because the backend variants
// can differ considerably in size.
// NOTE(review): presumably boxing the largest variant was judged not worth the
// extra allocation/indirection per stream — confirm the rationale.
#[allow(clippy::large_enum_variant)]
enum TlsStreamInner<S> {
    /// Stream produced by the `native-tls` backend.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsStream<SyncStream<S>>),
    /// Stream produced by the `rustls` backend; holds either a client or a
    /// server session.
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsStream<Pin<Box<AsyncStream<S>>>>),
    /// Stream produced by the dynamically loaded OpenSSL backend.
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(compio_py_dynamic_openssl::ssl::SslStream<SyncStream<S>>),
    /// Placeholder used when no backend feature is enabled. It can never be
    /// constructed because `Infallible` is uninhabited; `PhantomData<S>` keeps
    /// the type parameter `S` in use.
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl",
    )))]
    None(std::convert::Infallible, std::marker::PhantomData<S>),
}
27
28impl<S> TlsStreamInner<S> {
29    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
30        match self {
31            #[cfg(feature = "native-tls")]
32            Self::NativeTls(s) => s.negotiated_alpn().ok().flatten().map(Cow::from),
33            #[cfg(feature = "rustls")]
34            Self::Rustls(s) => s.get_ref().1.alpn_protocol().map(Cow::from),
35            #[cfg(feature = "py-dynamic-openssl")]
36            Self::PyDynamicOpenSsl(s) => crate::py_ossl::negotiated_alpn(s),
37            #[cfg(not(any(
38                feature = "native-tls",
39                feature = "rustls",
40                feature = "py-dynamic-openssl",
41            )))]
42            Self::None(f, ..) => match *f {},
43        }
44    }
45}
46
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// A `TlsStream<S>` represents a connection whose handshake has completed
/// successfully, so both the client and the server are ready to send and
/// receive data. Bytes read from a `TlsStream` are decrypted from `S`, and bytes
/// written to a `TlsStream` are encrypted before being passed through to `S`.
///
/// The TLS implementation backing the stream is selected by the enabled cargo
/// features (`native-tls`, `rustls`, or `py-dynamic-openssl`).
#[derive(Debug)]
pub struct TlsStream<S>(TlsStreamInner<S>);
56
57impl<S> TlsStream<S> {
58    /// Returns the negotiated ALPN protocol.
59    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
60        self.0.negotiated_alpn()
61    }
62}
63
#[cfg(feature = "native-tls")]
#[doc(hidden)]
impl<S> From<native_tls::TlsStream<SyncStream<S>>> for TlsStream<S> {
    /// Wraps a `native-tls` stream in the backend-agnostic [`TlsStream`].
    fn from(stream: native_tls::TlsStream<SyncStream<S>>) -> Self {
        let inner = TlsStreamInner::NativeTls(stream);
        TlsStream(inner)
    }
}
71
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S> From<futures_rustls::client::TlsStream<Pin<Box<AsyncStream<S>>>>> for TlsStream<S> {
    /// Wraps a client-side `rustls` session in the shared `Rustls` variant.
    fn from(stream: futures_rustls::client::TlsStream<Pin<Box<AsyncStream<S>>>>) -> Self {
        let session = futures_rustls::TlsStream::Client(stream);
        TlsStream(TlsStreamInner::Rustls(session))
    }
}
81
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S> From<futures_rustls::server::TlsStream<Pin<Box<AsyncStream<S>>>>> for TlsStream<S> {
    /// Wraps a server-side `rustls` session in the shared `Rustls` variant.
    fn from(stream: futures_rustls::server::TlsStream<Pin<Box<AsyncStream<S>>>>) -> Self {
        let session = futures_rustls::TlsStream::Server(stream);
        TlsStream(TlsStreamInner::Rustls(session))
    }
}
91
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl<S> From<compio_py_dynamic_openssl::ssl::SslStream<SyncStream<S>>> for TlsStream<S> {
    /// Wraps a dynamically loaded OpenSSL stream in the backend-agnostic
    /// [`TlsStream`].
    fn from(stream: compio_py_dynamic_openssl::ssl::SslStream<SyncStream<S>>) -> Self {
        let inner = TlsStreamInner::PyDynamicOpenSsl(stream);
        TlsStream(inner)
    }
}
99
/// Runs the synchronous `native-tls` operation `f` to completion on top of
/// async I/O.
///
/// Whenever `f` fails with `WouldBlock`, the underlying [`SyncStream`] either
/// flushes its pending write buffer (if it has one) or fills its read buffer,
/// and `f` is retried. Any other error is returned immediately. On success,
/// pending writes are flushed before the result is returned.
#[cfg(feature = "native-tls")]
#[inline]
async fn drive<S, F, T>(s: &mut native_tls::TlsStream<SyncStream<S>>, mut f: F) -> io::Result<T>
where
    S: AsyncRead + AsyncWrite,
    F: FnMut(&mut native_tls::TlsStream<SyncStream<S>>) -> io::Result<T>,
{
    loop {
        let err = match f(s) {
            Ok(output) => {
                // The operation may have left encrypted bytes buffered; push
                // them out before handing the result back.
                let sync = s.get_mut();
                if sync.has_pending_write() {
                    sync.flush_write_buf().await?;
                }
                return Ok(output);
            }
            Err(err) => err,
        };

        if err.kind() != io::ErrorKind::WouldBlock {
            return Err(err);
        }

        // Blocked: drain outgoing bytes first if there are any, otherwise
        // read more data into the read buffer, then retry.
        let sync = s.get_mut();
        if sync.has_pending_write() {
            sync.flush_write_buf().await?;
        } else {
            sync.fill_read_buf().await?;
        }
    }
}
128
impl<S: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for TlsStream<S>
where
    for<'a> &'a S: AsyncRead + AsyncWrite,
{
    /// Reads decrypted bytes from the TLS stream into `buf`.
    ///
    /// On success, the number of bytes read is returned and
    /// `buf.advance_to` is called with that count. With the `rustls` backend,
    /// an `UnexpectedEof` error is reported as `Ok(0)` (end of stream). The
    /// other backends return their errors unchanged.
    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
        // The backends read into a plain `&mut [u8]`, so zero-fill the slice
        // returned by `as_uninit()` before reinterpreting it as initialized
        // bytes.
        let slice = buf.as_uninit();
        slice.fill(MaybeUninit::new(0));
        // SAFETY: Every element was initialized by the `fill` above, and
        // `MaybeUninit<u8>` has the same size and alignment as `u8`.
        let slice =
            unsafe { std::slice::from_raw_parts_mut::<u8>(slice.as_mut_ptr().cast(), slice.len()) };
        match &mut self.0 {
            #[cfg(feature = "native-tls")]
            TlsStreamInner::NativeTls(s) => match drive(s, |s| io::Read::read(s, slice)).await {
                Ok(res) => {
                    // SAFETY: the backend wrote `res` bytes into `slice`, and
                    // the `io::Read` contract guarantees `res <= slice.len()`.
                    unsafe { buf.advance_to(res) };
                    BufResult(Ok(res), buf)
                }
                res => BufResult(res, buf),
            },
            #[cfg(feature = "rustls")]
            TlsStreamInner::Rustls(s) => {
                let res = futures_util::AsyncReadExt::read(s, slice).await;
                let res = match res {
                    Ok(len) => {
                        // SAFETY: the backend wrote `len` bytes into `slice`;
                        // `AsyncRead` reports at most `slice.len()` bytes.
                        unsafe { buf.advance_to(len) };
                        Ok(len)
                    }
                    // TLS streams may return UnexpectedEof when the connection is closed.
                    // https://docs.rs/rustls/latest/rustls/manual/_03_howto/index.html#unexpected-eof
                    Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(0),
                    _ => res,
                };
                BufResult(res, buf)
            }
            #[cfg(feature = "py-dynamic-openssl")]
            TlsStreamInner::PyDynamicOpenSsl(s) => match crate::py_ossl::read(s, slice).await {
                Ok(res) => {
                    // SAFETY: every byte of `slice` is initialized (see above).
                    // NOTE(review): assumes `py_ossl::read` never reports more
                    // than `slice.len()` bytes — confirm.
                    unsafe { buf.advance_to(res) };
                    BufResult(Ok(res), buf)
                }
                Err(e) => BufResult(Err(e), buf),
            },
            #[cfg(not(any(
                feature = "native-tls",
                feature = "rustls",
                feature = "py-dynamic-openssl",
            )))]
            TlsStreamInner::None(f, ..) => match *f {},
        }
    }
}
180
/// Flushes a `native-tls` stream and then drains the [`SyncStream`] write
/// buffer to the underlying async stream.
///
/// A `WouldBlock` from the synchronous flush is handled by draining the write
/// buffer asynchronously and retrying; any other error is returned.
#[cfg(feature = "native-tls")]
async fn flush_impl(s: &mut native_tls::TlsStream<SyncStream<impl AsyncWrite>>) -> io::Result<()> {
    while let Err(err) = io::Write::flush(s) {
        if err.kind() != io::ErrorKind::WouldBlock {
            return Err(err);
        }
        s.get_mut().flush_write_buf().await?;
    }
    // Push out whatever the successful flush left in the write buffer.
    s.get_mut().flush_write_buf().await.map(|_| ())
}
195
impl<S: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for TlsStream<S>
where
    for<'a> &'a S: AsyncRead + AsyncWrite,
{
    /// Encrypts and writes the initialized bytes of `buf`, returning the byte
    /// count reported by the backend's write operation.
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let slice = buf.as_init();
        match &mut self.0 {
            #[cfg(feature = "native-tls")]
            TlsStreamInner::NativeTls(s) => {
                let res = drive(s, |s| io::Write::write(s, slice)).await;
                BufResult(res, buf)
            }
            #[cfg(feature = "rustls")]
            TlsStreamInner::Rustls(s) => {
                let res = futures_util::AsyncWriteExt::write(s, slice).await;
                BufResult(res, buf)
            }
            #[cfg(feature = "py-dynamic-openssl")]
            TlsStreamInner::PyDynamicOpenSsl(s) => {
                let res = crate::py_ossl::write(s, slice).await;
                BufResult(res, buf)
            }
            #[cfg(not(any(
                feature = "native-tls",
                feature = "rustls",
                feature = "py-dynamic-openssl",
            )))]
            TlsStreamInner::None(f, ..) => match *f {},
        }
    }

    /// Flushes buffered outgoing data towards the underlying stream.
    async fn flush(&mut self) -> io::Result<()> {
        match &mut self.0 {
            #[cfg(feature = "native-tls")]
            TlsStreamInner::NativeTls(s) => flush_impl(s).await,
            #[cfg(feature = "rustls")]
            TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::flush(s).await,
            // Only the `SyncStream` write buffer is drained here; the
            // `SslStream` layer itself is not flushed separately.
            // NOTE(review): presumably OpenSSL keeps no application-data
            // buffer of its own here — confirm.
            #[cfg(feature = "py-dynamic-openssl")]
            TlsStreamInner::PyDynamicOpenSsl(s) => s.get_mut().flush_write_buf().await.map(|_| ()),
            #[cfg(not(any(
                feature = "native-tls",
                feature = "rustls",
                feature = "py-dynamic-openssl",
            )))]
            TlsStreamInner::None(f, ..) => match *f {},
        }
    }

    /// Flushes pending data, then performs the backend-specific TLS shutdown.
    async fn shutdown(&mut self) -> io::Result<()> {
        self.flush().await?;
        match &mut self.0 {
            #[cfg(feature = "native-tls")]
            TlsStreamInner::NativeTls(s) => {
                // Send the close_notify alert, then shut down the underlying
                // stream. This relies on platform-specific native-tls
                // behavior: the first call consistently sends close_notify,
                // but later calls may or may not block waiting for the peer's
                // own close_notify. That behavior is propagated as-is, so
                // callers should call shutdown() at most once.
                drive(s, |s| s.shutdown()).await?;
                s.get_mut().get_mut().shutdown().await
            }
            #[cfg(feature = "rustls")]
            TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::close(s).await,
            #[cfg(feature = "py-dynamic-openssl")]
            TlsStreamInner::PyDynamicOpenSsl(s) => crate::py_ossl::shutdown(s).await,
            #[cfg(not(any(
                feature = "native-tls",
                feature = "rustls",
                feature = "py-dynamic-openssl",
            )))]
            TlsStreamInner::None(f, ..) => match *f {},
        }
    }
}