Skip to main content

compio_ws/
tls.rs

1//! TLS support for WebSocket connections (native-tls and rustls).
2
3use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
4use compio_net::TcpStream;
5use compio_tls::{MaybeTlsStream, TlsConnector};
6use tungstenite::{
7    Error,
8    client::{IntoClientRequest, uri_mode},
9    handshake::client::{Request, Response},
10    stream::Mode,
11};
12
13use crate::{Config, WebSocketStream, client_async_with_config};
14
mod encryption {
    //! Construction of default TLS connectors for each supported backend.

    /// `native-tls` backend.
    #[cfg(feature = "native-tls")]
    pub mod native_tls {
        use compio_tls::{TlsConnector, native_tls};
        use tungstenite::{Error, error::TlsError};

        /// Creates a [`TlsConnector`] wrapping `native_tls::TlsConnector::new()`.
        ///
        /// # Errors
        ///
        /// Returns the backend's construction failure, converted through
        /// [`TlsError`] into [`Error`].
        pub fn new_connector() -> Result<TlsConnector, Error> {
            match native_tls::TlsConnector::new() {
                Ok(inner) => Ok(TlsConnector::from(inner)),
                Err(err) => Err(TlsError::from(err).into()),
            }
        }
    }

    /// `rustls` backend.
    #[cfg(feature = "rustls")]
    pub mod rustls {
        use std::sync::Arc;

        use compio_tls::{
            TlsConnector,
            rustls::{ClientConfig, RootCertStore},
        };
        use tungstenite::Error;

        /// Builds a client config whose root store is populated from the
        /// enabled certificate features (`rustls-native-certs` and/or
        /// `webpki-roots`), without client authentication.
        ///
        /// # Errors
        ///
        /// When `rustls-native-certs` is enabled but `webpki-roots` is not,
        /// returns a `NotFound` I/O error if no native root certificates were
        /// loaded.
        fn config_with_certs() -> Result<Arc<ClientConfig>, Error> {
            // `mut` goes unused when neither certificate feature is enabled.
            #[allow(unused_mut)]
            let mut roots = RootCertStore::empty();

            #[cfg(feature = "rustls-native-certs")]
            {
                let rustls_native_certs::CertificateResult { certs, errors, .. } =
                    rustls_native_certs::load_native_certs();

                if !errors.is_empty() {
                    compio_log::warn!("native root CA certificate loading errors: {errors:?}");
                }

                // An empty native store is only fatal when there are no
                // `webpki-roots` certificates to add afterwards.
                #[cfg(not(feature = "webpki-roots"))]
                if certs.is_empty() {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        format!("no native root CA certificates found (errors: {errors:?})"),
                    )
                    .into());
                }

                let total_number = certs.len();
                let (number_added, number_ignored) = roots.add_parsable_certificates(certs);
                compio_log::debug!(
                    "Added {number_added}/{total_number} native root certificates (ignored \
                     {number_ignored})"
                );
            }

            #[cfg(feature = "webpki-roots")]
            {
                let bundled = webpki_roots::TLS_SERVER_ROOTS.iter().cloned();
                roots.extend(bundled);
            }

            let config = ClientConfig::builder()
                .with_root_certificates(roots)
                .with_no_client_auth();
            Ok(Arc::new(config))
        }

        /// Builds a client config that delegates certificate verification to
        /// `rustls_platform_verifier`, without client authentication.
        ///
        /// # Errors
        ///
        /// Returns the verifier setup failure converted into a TLS error.
        #[cfg(feature = "rustls-platform-verifier")]
        fn config_with_platform_verifier() -> Result<Arc<ClientConfig>, Error> {
            use rustls_platform_verifier::BuilderVerifierExt;

            let builder = ClientConfig::builder()
                .with_platform_verifier()
                .map_err(tungstenite::error::TlsError::from)?;
            Ok(Arc::new(builder.with_no_client_auth()))
        }

        /// Creates a rustls [`TlsConnector`].
        ///
        /// With `rustls-platform-verifier` enabled, the platform-verifier
        /// config is tried first. If it cannot be created, a warning is logged
        /// and the certificate-feature config (`config_with_certs`) is used
        /// instead. Without that feature, the certificate-feature config is
        /// used directly.
        ///
        /// # Errors
        ///
        /// Propagates the error from `config_with_certs` when it is used and
        /// fails.
        pub fn new_connector() -> Result<TlsConnector, Error> {
            #[cfg(feature = "rustls-platform-verifier")]
            let config = config_with_platform_verifier().or_else(|e| {
                compio_log::warn!("Error creating platform verifier: {e:?}");
                config_with_certs()
            })?;

            #[cfg(not(feature = "rustls-platform-verifier"))]
            let config = config_with_certs()?;

            Ok(TlsConnector::from(config))
        }
    }
}
113
114async fn wrap_stream<S>(
115    socket: S,
116    domain: &str,
117    connector: Option<TlsConnector>,
118    mode: Mode,
119    cap: usize,
120    max_buffer_size: usize,
121) -> Result<MaybeTlsStream<S>, Error>
122where
123    S: Splittable + 'static,
124    S::ReadHalf: AsyncRead + Unpin,
125    S::WriteHalf: AsyncWrite + Unpin,
126{
127    let socket = AsyncStream::with_limits(cap, max_buffer_size, socket);
128    match mode {
129        Mode::Plain => Ok(MaybeTlsStream::new_plain_compat(socket)),
130        Mode::Tls => {
131            let stream = {
132                let connector = if let Some(connector) = connector {
133                    connector
134                } else {
135                    #[cfg(feature = "native-tls")]
136                    {
137                        match encryption::native_tls::new_connector() {
138                            Ok(c) => c,
139                            Err(e) => {
140                                compio_log::warn!(
141                                    "Falling back to rustls TLS connector due to native-tls \
142                                     error: {e:?}",
143                                );
144                                #[cfg(feature = "rustls")]
145                                {
146                                    encryption::rustls::new_connector()?
147                                }
148                                #[cfg(not(feature = "rustls"))]
149                                {
150                                    return Err(e);
151                                }
152                            }
153                        }
154                    }
155                    #[cfg(all(feature = "rustls", not(feature = "native-tls")))]
156                    {
157                        encryption::rustls::new_connector()?
158                    }
159                    #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
160                    {
161                        return Err(Error::Url(
162                            tungstenite::error::UrlError::TlsFeatureNotEnabled,
163                        ));
164                    }
165                };
166
167                connector
168                    .connect_compat(domain, socket)
169                    .await
170                    .map_err(Error::Io)?
171            };
172            Ok(MaybeTlsStream::new_tls(stream))
173        }
174    }
175}
176
177/// Creates a WebSocket handshake from a request and a stream,
178/// upgrading the stream to TLS if required.
179pub async fn client_async_tls<R, S>(
180    request: R,
181    stream: S,
182) -> Result<(WebSocketStream<S>, Response), Error>
183where
184    R: IntoClientRequest,
185    S: Splittable + 'static,
186    S::ReadHalf: AsyncRead + Unpin,
187    S::WriteHalf: AsyncWrite + Unpin,
188{
189    client_async_tls_with_config(request, stream, None, None).await
190}
191
192/// Similar to [`client_async_tls()`] but the one can specify a websocket
193/// configuration, and an optional connector.
194pub async fn client_async_tls_with_config<R, S>(
195    request: R,
196    stream: S,
197    connector: Option<TlsConnector>,
198    config: impl Into<Config>,
199) -> Result<(WebSocketStream<S>, Response), Error>
200where
201    R: IntoClientRequest,
202    S: Splittable + 'static,
203    S::ReadHalf: AsyncRead + Unpin,
204    S::WriteHalf: AsyncWrite + Unpin,
205{
206    let request: Request = request.into_client_request()?;
207
208    let domain = domain(&request)?;
209
210    let mode = uri_mode(request.uri())?;
211
212    let config = config.into();
213
214    let stream = wrap_stream(
215        stream,
216        domain,
217        connector,
218        mode,
219        config.buffer_size_base,
220        config.buffer_size_limit,
221    )
222    .await?;
223    client_async_with_config(request, stream, config).await
224}
225
226/// Connect to a given URL.
227pub async fn connect_async<R>(request: R) -> Result<(WebSocketStream<TcpStream>, Response), Error>
228where
229    R: IntoClientRequest,
230{
231    connect_async_with_config(request, None).await
232}
233
234/// Similar to [`connect_async()`], but user can specify a [`Config`].
235pub async fn connect_async_with_config<R>(
236    request: R,
237    config: impl Into<Config>,
238) -> Result<(WebSocketStream<TcpStream>, Response), Error>
239where
240    R: IntoClientRequest,
241{
242    connect_async_tls_with_config(request, config, None).await
243}
244
245/// Similar to [`connect_async()`], but user can specify a [`Config`] and an
246/// optional [`TlsConnector`].
247pub async fn connect_async_tls_with_config<R>(
248    request: R,
249    config: impl Into<Config>,
250    connector: Option<TlsConnector>,
251) -> Result<(WebSocketStream<TcpStream>, Response), Error>
252where
253    R: IntoClientRequest,
254{
255    let config = config.into();
256    let request: Request = request.into_client_request()?;
257
258    // We don't check if it's an IPv6 address because `std` handles it internally.
259    let domain = request
260        .uri()
261        .host()
262        .ok_or(Error::Url(tungstenite::error::UrlError::NoHostName))?;
263    let port = port(&request)?;
264
265    let socket = TcpStream::connect((domain, port))
266        .await
267        .map_err(Error::Io)?;
268    socket.set_nodelay(config.disable_nagle)?;
269    client_async_tls_with_config(request, socket, connector, config).await
270}
271
272#[inline]
273fn port(request: &Request) -> Result<u16, Error> {
274    request
275        .uri()
276        .port_u16()
277        .or_else(|| match uri_mode(request.uri()).ok()? {
278            Mode::Plain => Some(80),
279            Mode::Tls => Some(443),
280        })
281        .ok_or(Error::Url(
282            tungstenite::error::UrlError::UnsupportedUrlScheme,
283        ))
284}
285
286#[inline]
287fn domain(request: &Request) -> Result<&str, Error> {
288    request
289        .uri()
290        .host()
291        .map(|host| {
292            // If host is an IPv6 address, it might be surrounded by brackets. These
293            // brackets are *not* part of a valid IP, so they must be stripped
294            // out.
295            //
296            // The URI from the request is guaranteed to be valid, so we don't need a
297            // separate check for the closing bracket.
298
299            if host.starts_with('[') && host.ends_with(']') {
300                &host[1..host.len() - 1]
301            } else {
302                host
303            }
304        })
305        .ok_or(tungstenite::Error::Url(
306            tungstenite::error::UrlError::NoHostName,
307        ))
308}