Skip to main content

compio_tls/
adapter.rs

1use std::{fmt::Debug, io};
2
3use compio_io::{
4    AsyncRead, AsyncWrite,
5    compat::{AsyncStream, SyncStream},
6};
7
8use crate::TlsStream;
9
/// Backend-specific state held by [`TlsConnector`].
///
/// One variant is compiled in per enabled TLS backend feature. When no backend
/// feature is enabled, the only variant wraps an uninhabited `Infallible`, so
/// no value of this type can be constructed and every `match` on it remains
/// exhaustive.
#[derive(Clone)]
enum TlsConnectorInner {
    /// A `native_tls` connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsConnector),
    /// A `futures_rustls` connector.
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsConnector),
    /// A `compio_py_dynamic_openssl` SSL context.
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(compio_py_dynamic_openssl::SSLContext),
    /// Placeholder keeping the enum well-formed when no backend is enabled.
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl"
    )))]
    None(core::convert::Infallible),
}
25
26impl Debug for TlsConnectorInner {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            #[cfg(feature = "native-tls")]
30            Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
31            #[cfg(feature = "rustls")]
32            Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
33            #[cfg(feature = "py-dynamic-openssl")]
34            Self::PyDynamicOpenSsl(_) => f.debug_tuple("PyDynamicOpenSsl").finish(),
35            #[cfg(not(any(
36                feature = "native-tls",
37                feature = "rustls",
38                feature = "py-dynamic-openssl"
39            )))]
40            Self::None(f) => match *f {},
41        }
42    }
43}
44
45/// A wrapper around a [`native_tls::TlsConnector`] or [`rustls::ClientConfig`],
46/// providing an async `connect` method.
47#[derive(Debug, Clone)]
48pub struct TlsConnector(TlsConnectorInner);
49
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsConnector> for TlsConnector {
    /// Wraps an already-configured `native_tls` connector.
    fn from(connector: native_tls::TlsConnector) -> Self {
        TlsConnector(TlsConnectorInner::NativeTls(connector))
    }
}
56
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ClientConfig>> for TlsConnector {
    /// Builds a `futures_rustls` connector from a shared rustls client config.
    fn from(config: std::sync::Arc<rustls::ClientConfig>) -> Self {
        let connector = futures_rustls::TlsConnector::from(config);
        TlsConnector(TlsConnectorInner::Rustls(connector))
    }
}
63
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl From<compio_py_dynamic_openssl::SSLContext> for TlsConnector {
    /// Wraps a `compio_py_dynamic_openssl` SSL context for client use.
    fn from(context: compio_py_dynamic_openssl::SSLContext) -> Self {
        TlsConnector(TlsConnectorInner::PyDynamicOpenSsl(context))
    }
}
71
72impl TlsConnector {
73    /// Connects the provided stream with this connector, assuming the provided
74    /// domain.
75    ///
76    /// This function will internally call `TlsConnector::connect` to connect
77    /// the stream and returns a future representing the resolution of the
78    /// connection operation. The returned future will resolve to either
79    /// `TlsStream<S>` or `Error` depending if it's successful or not.
80    ///
81    /// This is typically used for clients who have already established, for
82    /// example, a TCP connection to a remote server. That stream is then
83    /// provided here to perform the client half of a connection to a
84    /// TLS-powered server.
85    pub async fn connect<S: AsyncRead + AsyncWrite + Unpin + 'static>(
86        &self,
87        domain: &str,
88        stream: S,
89    ) -> io::Result<TlsStream<S>>
90    where
91        for<'a> &'a S: AsyncRead + AsyncWrite,
92    {
93        match &self.0 {
94            #[cfg(feature = "native-tls")]
95            TlsConnectorInner::NativeTls(c) => {
96                handshake_native_tls(c.connect(domain, SyncStream::new(stream))).await
97            }
98            #[cfg(feature = "rustls")]
99            TlsConnectorInner::Rustls(c) => {
100                let client = c
101                    .connect(
102                        domain.to_string().try_into().map_err(io::Error::other)?,
103                        Box::pin(AsyncStream::new(stream)),
104                    )
105                    .await?;
106                Ok(TlsStream::from(client))
107            }
108            #[cfg(feature = "py-dynamic-openssl")]
109            TlsConnectorInner::PyDynamicOpenSsl(c) => {
110                crate::py_ossl::handshake(c.connect(domain, SyncStream::new(stream))).await
111            }
112            #[cfg(not(any(
113                feature = "native-tls",
114                feature = "rustls",
115                feature = "py-dynamic-openssl"
116            )))]
117            TlsConnectorInner::None(f) => match *f {},
118        }
119    }
120}
121
/// Backend-specific state held by [`TlsAcceptor`].
///
/// One variant is compiled in per enabled TLS backend feature. When no backend
/// feature is enabled, the only variant wraps an uninhabited `Infallible`, so
/// no value of this type can be constructed.
#[derive(Clone)]
enum TlsAcceptorInner {
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsAcceptor),
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsAcceptor),
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(compio_py_dynamic_openssl::SSLContext),
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl"
    )))]
    None(std::convert::Infallible),
}

// Hand-written (mirroring `TlsConnectorInner`) so that only the backend name is
// printed and the wrapped backend handles need not implement `Debug`.
impl Debug for TlsAcceptorInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "native-tls")]
            Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
            #[cfg(feature = "rustls")]
            Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
            #[cfg(feature = "py-dynamic-openssl")]
            Self::PyDynamicOpenSsl(_) => f.debug_tuple("PyDynamicOpenSsl").finish(),
            #[cfg(not(any(
                feature = "native-tls",
                feature = "rustls",
                feature = "py-dynamic-openssl"
            )))]
            Self::None(never) => match *never {},
        }
    }
}

/// A wrapper around a [`native_tls::TlsAcceptor`] or [`rustls::ServerConfig`],
/// providing an async `accept` method.
///
/// [`native_tls::TlsAcceptor`]: https://docs.rs/native-tls/latest/native_tls/struct.TlsAcceptor.html
/// [`rustls::ServerConfig`]: https://docs.rs/rustls/latest/rustls/server/struct.ServerConfig.html
// `Debug` is derived for parity with `TlsConnector`; public types should be debuggable.
#[derive(Debug, Clone)]
pub struct TlsAcceptor(TlsAcceptorInner);
145
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
    /// Wraps an already-configured `native_tls` acceptor.
    fn from(acceptor: native_tls::TlsAcceptor) -> Self {
        TlsAcceptor(TlsAcceptorInner::NativeTls(acceptor))
    }
}
152
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ServerConfig>> for TlsAcceptor {
    /// Builds a `futures_rustls` acceptor from a shared rustls server config.
    fn from(config: std::sync::Arc<rustls::ServerConfig>) -> Self {
        let acceptor = futures_rustls::TlsAcceptor::from(config);
        TlsAcceptor(TlsAcceptorInner::Rustls(acceptor))
    }
}
159
// Hidden from rustdoc, matching the corresponding `From<SSLContext> for TlsConnector`
// impl: the py-dynamic-openssl backend is not part of the documented public API.
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl From<compio_py_dynamic_openssl::SSLContext> for TlsAcceptor {
    /// Wraps a `compio_py_dynamic_openssl` SSL context for server use.
    fn from(value: compio_py_dynamic_openssl::SSLContext) -> Self {
        Self(TlsAcceptorInner::PyDynamicOpenSsl(value))
    }
}
166
167impl TlsAcceptor {
168    /// Accepts a new client connection with the provided stream.
169    ///
170    /// This function will internally call `TlsAcceptor::accept` to connect
171    /// the stream and returns a future representing the resolution of the
172    /// connection operation. The returned future will resolve to either
173    /// `TlsStream<S>` or `Error` depending if it's successful or not.
174    ///
175    /// This is typically used after a new socket has been accepted from a
176    /// `TcpListener`. That socket is then passed to this function to perform
177    /// the server half of accepting a client connection.
178    pub async fn accept<S: AsyncRead + AsyncWrite + Unpin + 'static>(
179        &self,
180        stream: S,
181    ) -> io::Result<TlsStream<S>>
182    where
183        for<'a> &'a S: AsyncRead + AsyncWrite,
184    {
185        match &self.0 {
186            #[cfg(feature = "native-tls")]
187            TlsAcceptorInner::NativeTls(c) => {
188                handshake_native_tls(c.accept(SyncStream::new(stream))).await
189            }
190            #[cfg(feature = "rustls")]
191            TlsAcceptorInner::Rustls(c) => {
192                let server = c.accept(Box::pin(AsyncStream::new(stream))).await?;
193                Ok(TlsStream::from(server))
194            }
195            #[cfg(feature = "py-dynamic-openssl")]
196            TlsAcceptorInner::PyDynamicOpenSsl(a) => {
197                crate::py_ossl::handshake(a.accept(SyncStream::new(stream))).await
198            }
199            #[cfg(not(any(
200                feature = "native-tls",
201                feature = "rustls",
202                feature = "py-dynamic-openssl"
203            )))]
204            TlsAcceptorInner::None(f) => match *f {},
205        }
206    }
207}
208
/// Drives a `native_tls` handshake to completion on top of a [`SyncStream`].
///
/// `res` is the outcome of the backend's initial, synchronous `connect` or
/// `accept` call. While the backend reports `WouldBlock`, the paused handshake's
/// stream either flushes its pending write buffer (`flush_write_buf`) or, when
/// nothing is pending, fills its read buffer (`fill_read_buf`), and then the
/// handshake is resumed. On success, any pending write is flushed before the
/// stream is wrapped in a [`TlsStream`].
///
/// # Errors
///
/// A backend handshake failure is converted with [`io::Error::other`]; errors
/// from flushing or filling the buffers are propagated unchanged.
#[cfg(feature = "native-tls")]
async fn handshake_native_tls<S: AsyncRead + AsyncWrite>(
    mut res: Result<
        native_tls::TlsStream<SyncStream<S>>,
        native_tls::HandshakeError<SyncStream<S>>,
    >,
) -> io::Result<TlsStream<S>> {
    use native_tls::HandshakeError;

    loop {
        // Finish, fail, or pull out the paused handshake so it can be resumed.
        let mut paused = match res {
            Ok(mut established) => {
                let transport = established.get_mut();
                if transport.has_pending_write() {
                    transport.flush_write_buf().await?;
                }
                return Ok(TlsStream::from(established));
            }
            Err(HandshakeError::Failure(err)) => return Err(io::Error::other(err)),
            Err(HandshakeError::WouldBlock(mid_stream)) => mid_stream,
        };

        // The backend is blocked on I/O: send what it wrote, otherwise read more.
        let transport = paused.get_mut();
        if transport.has_pending_write() {
            transport.flush_write_buf().await?;
        } else {
            transport.fill_read_buf().await?;
        }
        res = paused.handshake();
    }
}