Skip to main content

compio_tls/
adapter.rs

1use std::{fmt::Debug, io};
2
3use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
4
5use crate::TlsStream;
6
/// Backend-specific connector held by [`TlsConnector`].
///
/// Each variant exists only when its Cargo feature is enabled, so the set of
/// variants depends on the features the crate is built with.
#[derive(Clone)]
enum TlsConnectorInner {
    /// Connector from the crate's `native` module (`native-tls` feature).
    #[cfg(feature = "native-tls")]
    NativeTls(crate::native::TlsConnector),
    /// Connector from `futures_rustls` (`rustls` feature).
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsConnector),
    /// Connector from the crate's `py_ossl` module (`py-dynamic-openssl`
    /// feature).
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(crate::py_ossl::TlsConnector),
    /// Placeholder compiled in only when no backend feature is enabled. It
    /// wraps the uninhabited `Infallible` type, so it can never be
    /// constructed.
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl"
    )))]
    None(std::convert::Infallible),
}
22
23impl Debug for TlsConnectorInner {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            #[cfg(feature = "native-tls")]
27            Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
28            #[cfg(feature = "rustls")]
29            Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
30            #[cfg(feature = "py-dynamic-openssl")]
31            Self::PyDynamicOpenSsl(_) => f.debug_tuple("PyDynamicOpenSsl").finish(),
32            #[cfg(not(any(
33                feature = "native-tls",
34                feature = "rustls",
35                feature = "py-dynamic-openssl"
36            )))]
37            Self::None(f) => match *f {},
38        }
39    }
40}
41
/// A wrapper around a [`native_tls::TlsConnector`] or [`rustls::ClientConfig`],
/// providing an async `connect` method.
///
/// Values are created through the `From` conversions implemented for this
/// type.
///
/// [`native_tls::TlsConnector`]: https://docs.rs/native-tls/latest/native_tls/struct.TlsConnector.html
/// [`rustls::ClientConfig`]: https://docs.rs/rustls/latest/rustls/client/struct.ClientConfig.html
#[derive(Debug, Clone)]
pub struct TlsConnector(TlsConnectorInner);
46
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsConnector> for TlsConnector {
    /// Wraps a `native_tls` connector, converting it into the crate's own
    /// `native` connector type first.
    fn from(connector: native_tls::TlsConnector) -> Self {
        let inner = TlsConnectorInner::NativeTls(connector.into());
        Self(inner)
    }
}
53
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ClientConfig>> for TlsConnector {
    /// Wraps a shared rustls client configuration, converting it into a
    /// `futures_rustls` connector first.
    fn from(config: std::sync::Arc<rustls::ClientConfig>) -> Self {
        let inner = TlsConnectorInner::Rustls(config.into());
        Self(inner)
    }
}
60
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl From<compio_py_dynamic_openssl::SSLContext> for TlsConnector {
    /// Wraps an `SSLContext`, converting it into the crate's `py_ossl`
    /// connector type first.
    fn from(context: compio_py_dynamic_openssl::SSLContext) -> Self {
        let inner = TlsConnectorInner::PyDynamicOpenSsl(context.into());
        Self(inner)
    }
}
68
69impl TlsConnector {
70    /// Connects the provided stream with this connector, assuming the provided
71    /// domain.
72    ///
73    /// This function will internally call `TlsConnector::connect` to connect
74    /// the stream and returns a future representing the resolution of the
75    /// connection operation. The returned future will resolve to either
76    /// `TlsStream<S>` or `Error` depending if it's successful or not.
77    ///
78    /// This is typically used for clients who have already established, for
79    /// example, a TCP connection to a remote server. That stream is then
80    /// provided here to perform the client half of a connection to a
81    /// TLS-powered server.
82    pub async fn connect<S: Splittable + 'static>(
83        &self,
84        domain: &str,
85        stream: S,
86    ) -> io::Result<TlsStream<S>>
87    where
88        S::ReadHalf: AsyncRead + Unpin,
89        S::WriteHalf: AsyncWrite + Unpin,
90    {
91        self.connect_compat(domain, AsyncStream::new(stream)).await
92    }
93
94    /// Similar to `connect` but accepts an [`AsyncStream`] instead of a raw
95    /// stream. Users are free to adjust the inner buffer sizes and limits.
96    pub async fn connect_compat<S: Splittable + 'static>(
97        &self,
98        domain: &str,
99        stream: AsyncStream<S>,
100    ) -> io::Result<TlsStream<S>>
101    where
102        S::ReadHalf: AsyncRead + Unpin,
103        S::WriteHalf: AsyncWrite + Unpin,
104    {
105        match &self.0 {
106            #[cfg(feature = "native-tls")]
107            TlsConnectorInner::NativeTls(c) => {
108                let client = c.connect(domain, Box::pin(stream)).await?;
109                Ok(TlsStream::from(client))
110            }
111            #[cfg(feature = "rustls")]
112            TlsConnectorInner::Rustls(c) => {
113                let client = c
114                    .connect(
115                        domain.to_string().try_into().map_err(io::Error::other)?,
116                        Box::pin(stream),
117                    )
118                    .await?;
119                Ok(TlsStream::from(client))
120            }
121            #[cfg(feature = "py-dynamic-openssl")]
122            TlsConnectorInner::PyDynamicOpenSsl(c) => {
123                let client = c.connect(domain, Box::pin(stream)).await?;
124                Ok(TlsStream::from(client))
125            }
126            #[cfg(not(any(
127                feature = "native-tls",
128                feature = "rustls",
129                feature = "py-dynamic-openssl"
130            )))]
131            TlsConnectorInner::None(f) => match *f {},
132        }
133    }
134}
135
/// Backend-specific acceptor held by [`TlsAcceptor`].
///
/// Each variant exists only when its Cargo feature is enabled, so the set of
/// variants depends on the features the crate is built with.
#[derive(Clone)]
enum TlsAcceptorInner {
    /// Acceptor from the crate's `native` module (`native-tls` feature).
    #[cfg(feature = "native-tls")]
    NativeTls(crate::native::TlsAcceptor),
    /// Acceptor from `futures_rustls` (`rustls` feature).
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsAcceptor),
    /// Acceptor from the crate's `py_ossl` module (`py-dynamic-openssl`
    /// feature).
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(crate::py_ossl::TlsAcceptor),
    /// Placeholder compiled in only when no backend feature is enabled. It
    /// wraps the uninhabited `Infallible` type, so it can never be
    /// constructed.
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl"
    )))]
    None(std::convert::Infallible),
}

// Hand-written (like the impl for `TlsConnectorInner`) so that only the
// backend name is printed and the wrapped acceptor types are not required to
// implement `Debug` themselves.
impl Debug for TlsAcceptorInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "native-tls")]
            Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
            #[cfg(feature = "rustls")]
            Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
            #[cfg(feature = "py-dynamic-openssl")]
            Self::PyDynamicOpenSsl(_) => f.debug_tuple("PyDynamicOpenSsl").finish(),
            #[cfg(not(any(
                feature = "native-tls",
                feature = "rustls",
                feature = "py-dynamic-openssl"
            )))]
            Self::None(never) => match *never {},
        }
    }
}

/// A wrapper around a [`native_tls::TlsAcceptor`] or [`rustls::ServerConfig`],
/// providing an async `accept` method.
///
/// Values are created through the `From` conversions implemented for this
/// type. Its `Debug` output names the active backend only.
///
/// [`native_tls::TlsAcceptor`]: https://docs.rs/native-tls/latest/native_tls/struct.TlsAcceptor.html
/// [`rustls::ServerConfig`]: https://docs.rs/rustls/latest/rustls/server/struct.ServerConfig.html
#[derive(Debug, Clone)]
pub struct TlsAcceptor(TlsAcceptorInner);
159
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
    /// Wraps a `native_tls` acceptor, converting it into the crate's own
    /// `native` acceptor type first.
    fn from(acceptor: native_tls::TlsAcceptor) -> Self {
        let inner = TlsAcceptorInner::NativeTls(acceptor.into());
        Self(inner)
    }
}
166
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ServerConfig>> for TlsAcceptor {
    /// Wraps a shared rustls server configuration, converting it into a
    /// `futures_rustls` acceptor first.
    fn from(config: std::sync::Arc<rustls::ServerConfig>) -> Self {
        let inner = TlsAcceptorInner::Rustls(config.into());
        Self(inner)
    }
}
173
#[cfg(feature = "py-dynamic-openssl")]
// Hidden from the generated docs, matching the equivalent conversion
// implemented for `TlsConnector`.
#[doc(hidden)]
impl From<compio_py_dynamic_openssl::SSLContext> for TlsAcceptor {
    /// Wraps an `SSLContext`, converting it into the crate's `py_ossl`
    /// acceptor type first.
    fn from(value: compio_py_dynamic_openssl::SSLContext) -> Self {
        Self(TlsAcceptorInner::PyDynamicOpenSsl(value.into()))
    }
}
180
181impl TlsAcceptor {
182    /// Accepts a new client connection with the provided stream.
183    ///
184    /// This function will internally call `TlsAcceptor::accept` to connect
185    /// the stream and returns a future representing the resolution of the
186    /// connection operation. The returned future will resolve to either
187    /// `TlsStream<S>` or `Error` depending if it's successful or not.
188    ///
189    /// This is typically used after a new socket has been accepted from a
190    /// `TcpListener`. That socket is then passed to this function to perform
191    /// the server half of accepting a client connection.
192    pub async fn accept<S: Splittable + 'static>(&self, stream: S) -> io::Result<TlsStream<S>>
193    where
194        S::ReadHalf: AsyncRead + Unpin,
195        S::WriteHalf: AsyncWrite + Unpin,
196    {
197        self.accept_compat(AsyncStream::new(stream)).await
198    }
199
200    /// Similar to `accept` but accepts an [`AsyncStream`] instead of a raw
201    /// stream. Users are free to adjust the inner buffer sizes and limits.
202    pub async fn accept_compat<S: Splittable + 'static>(
203        &self,
204        stream: AsyncStream<S>,
205    ) -> io::Result<TlsStream<S>>
206    where
207        S::ReadHalf: AsyncRead + Unpin,
208        S::WriteHalf: AsyncWrite + Unpin,
209    {
210        match &self.0 {
211            #[cfg(feature = "native-tls")]
212            TlsAcceptorInner::NativeTls(c) => {
213                let server = c.accept(Box::pin(stream)).await?;
214                Ok(TlsStream::from(server))
215            }
216            #[cfg(feature = "rustls")]
217            TlsAcceptorInner::Rustls(c) => {
218                let server = c.accept(Box::pin(stream)).await?;
219                Ok(TlsStream::from(server))
220            }
221            #[cfg(feature = "py-dynamic-openssl")]
222            TlsAcceptorInner::PyDynamicOpenSsl(a) => {
223                let server = a.accept(Box::pin(stream)).await?;
224                Ok(TlsStream::from(server))
225            }
226            #[cfg(not(any(
227                feature = "native-tls",
228                feature = "rustls",
229                feature = "py-dynamic-openssl"
230            )))]
231            TlsAcceptorInner::None(f) => match *f {},
232        }
233    }
234}