1use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
4use compio_net::TcpStream;
5use compio_tls::{MaybeTlsStream, TlsConnector};
6use tungstenite::{
7 Error,
8 client::{IntoClientRequest, uri_mode},
9 handshake::client::{Request, Response},
10 stream::Mode,
11};
12
13use crate::{Config, WebSocketStream, client_async_with_config};
14
mod encryption {
    /// TLS connector construction backed by the `native-tls` backend.
    #[cfg(feature = "native-tls")]
    pub mod native_tls {
        use compio_tls::{TlsConnector, native_tls};
        use tungstenite::{Error, error::TlsError};

        /// Builds a [`TlsConnector`] from a default `native_tls::TlsConnector`.
        ///
        /// # Errors
        /// Fails if the native backend cannot create a connector; its error is
        /// wrapped in a [`TlsError`] and converted into a tungstenite [`Error`].
        pub fn new_connector() -> Result<TlsConnector, Error> {
            native_tls::TlsConnector::new()
                .map(TlsConnector::from)
                .map_err(|err| Error::from(TlsError::from(err)))
        }
    }

    /// TLS connector construction backed by `rustls`.
    #[cfg(feature = "rustls")]
    pub mod rustls {
        use std::sync::Arc;

        use compio_tls::{
            TlsConnector,
            rustls::{ClientConfig, RootCertStore},
        };
        use tungstenite::Error;

        /// Builds a client config whose trust anchors come from an explicit root store.
        ///
        /// The store is populated from native root certificates loaded through
        /// `rustls_native_certs` (feature `rustls-native-certs`) and/or from
        /// `webpki_roots::TLS_SERVER_ROOTS` (feature `webpki-roots`). With neither
        /// feature enabled the store stays empty.
        ///
        /// # Errors
        /// With `rustls-native-certs` enabled and `webpki-roots` disabled, returns an
        /// I/O error of kind `NotFound` when no native certificates were loaded.
        fn config_with_certs() -> Result<Arc<ClientConfig>, Error> {
            // `mut` goes unused when no root-certificate feature is enabled.
            #[allow(unused_mut)]
            let mut roots = RootCertStore::empty();

            #[cfg(feature = "rustls-native-certs")]
            {
                let rustls_native_certs::CertificateResult { certs, errors, .. } =
                    rustls_native_certs::load_native_certs();

                if !errors.is_empty() {
                    compio_log::warn!("native root CA certificate loading errors: {errors:?}");
                }

                // Without the bundled roots to fall back on, an empty native set would
                // leave the config unable to verify any server, so fail early instead.
                #[cfg(not(feature = "webpki-roots"))]
                if certs.is_empty() {
                    let reason =
                        format!("no native root CA certificates found (errors: {errors:?})");
                    return Err(Error::Io(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        reason,
                    )));
                }

                let total_number = certs.len();
                let (number_added, number_ignored) = roots.add_parsable_certificates(certs);
                compio_log::debug!(
                    "Added {number_added}/{total_number} native root certificates (ignored \
                     {number_ignored})"
                );
            }

            #[cfg(feature = "webpki-roots")]
            {
                roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
            }

            let config = ClientConfig::builder()
                .with_root_certificates(roots)
                .with_no_client_auth();
            Ok(Arc::new(config))
        }

        /// Builds a client config that verifies servers through `rustls_platform_verifier`.
        ///
        /// # Errors
        /// Propagates the verifier-construction error, wrapped in a `TlsError`.
        #[cfg(feature = "rustls-platform-verifier")]
        fn config_with_platform_verifier() -> Result<Arc<ClientConfig>, Error> {
            use rustls_platform_verifier::BuilderVerifierExt;

            ClientConfig::builder()
                .with_platform_verifier()
                .map(|builder| Arc::new(builder.with_no_client_auth()))
                .map_err(|err| Error::from(tungstenite::error::TlsError::from(err)))
        }

        /// Creates a rustls-backed [`TlsConnector`].
        ///
        /// With `rustls-platform-verifier` enabled, the platform verifier is tried first.
        /// If it cannot be built, a warning is logged and the root-store config from
        /// `config_with_certs` is used instead.
        pub fn new_connector() -> Result<TlsConnector, Error> {
            #[cfg(feature = "rustls-platform-verifier")]
            let config = config_with_platform_verifier().or_else(|e| {
                compio_log::warn!("Error creating platform verifier: {e:?}");
                config_with_certs()
            })?;
            #[cfg(not(feature = "rustls-platform-verifier"))]
            let config = config_with_certs()?;

            Ok(TlsConnector::from(config))
        }
    }
}
113
114async fn wrap_stream<S>(
115 socket: S,
116 domain: &str,
117 connector: Option<TlsConnector>,
118 mode: Mode,
119 cap: usize,
120 max_buffer_size: usize,
121) -> Result<MaybeTlsStream<S>, Error>
122where
123 S: Splittable + 'static,
124 S::ReadHalf: AsyncRead + Unpin,
125 S::WriteHalf: AsyncWrite + Unpin,
126{
127 let socket = AsyncStream::with_limits(cap, max_buffer_size, socket);
128 match mode {
129 Mode::Plain => Ok(MaybeTlsStream::new_plain_compat(socket)),
130 Mode::Tls => {
131 let stream = {
132 let connector = if let Some(connector) = connector {
133 connector
134 } else {
135 #[cfg(feature = "native-tls")]
136 {
137 match encryption::native_tls::new_connector() {
138 Ok(c) => c,
139 Err(e) => {
140 compio_log::warn!(
141 "Falling back to rustls TLS connector due to native-tls \
142 error: {e:?}",
143 );
144 #[cfg(feature = "rustls")]
145 {
146 encryption::rustls::new_connector()?
147 }
148 #[cfg(not(feature = "rustls"))]
149 {
150 return Err(e);
151 }
152 }
153 }
154 }
155 #[cfg(all(feature = "rustls", not(feature = "native-tls")))]
156 {
157 encryption::rustls::new_connector()?
158 }
159 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
160 {
161 return Err(Error::Url(
162 tungstenite::error::UrlError::TlsFeatureNotEnabled,
163 ));
164 }
165 };
166
167 connector
168 .connect_compat(domain, socket)
169 .await
170 .map_err(Error::Io)?
171 };
172 Ok(MaybeTlsStream::new_tls(stream))
173 }
174 }
175}
176
177pub async fn client_async_tls<R, S>(
180 request: R,
181 stream: S,
182) -> Result<(WebSocketStream<S>, Response), Error>
183where
184 R: IntoClientRequest,
185 S: Splittable + 'static,
186 S::ReadHalf: AsyncRead + Unpin,
187 S::WriteHalf: AsyncWrite + Unpin,
188{
189 client_async_tls_with_config(request, stream, None, None).await
190}
191
192pub async fn client_async_tls_with_config<R, S>(
195 request: R,
196 stream: S,
197 connector: Option<TlsConnector>,
198 config: impl Into<Config>,
199) -> Result<(WebSocketStream<S>, Response), Error>
200where
201 R: IntoClientRequest,
202 S: Splittable + 'static,
203 S::ReadHalf: AsyncRead + Unpin,
204 S::WriteHalf: AsyncWrite + Unpin,
205{
206 let request: Request = request.into_client_request()?;
207
208 let domain = domain(&request)?;
209
210 let mode = uri_mode(request.uri())?;
211
212 let config = config.into();
213
214 let stream = wrap_stream(
215 stream,
216 domain,
217 connector,
218 mode,
219 config.buffer_size_base,
220 config.buffer_size_limit,
221 )
222 .await?;
223 client_async_with_config(request, stream, config).await
224}
225
226pub async fn connect_async<R>(request: R) -> Result<(WebSocketStream<TcpStream>, Response), Error>
228where
229 R: IntoClientRequest,
230{
231 connect_async_with_config(request, None).await
232}
233
234pub async fn connect_async_with_config<R>(
236 request: R,
237 config: impl Into<Config>,
238) -> Result<(WebSocketStream<TcpStream>, Response), Error>
239where
240 R: IntoClientRequest,
241{
242 connect_async_tls_with_config(request, config, None).await
243}
244
245pub async fn connect_async_tls_with_config<R>(
248 request: R,
249 config: impl Into<Config>,
250 connector: Option<TlsConnector>,
251) -> Result<(WebSocketStream<TcpStream>, Response), Error>
252where
253 R: IntoClientRequest,
254{
255 let config = config.into();
256 let request: Request = request.into_client_request()?;
257
258 let domain = request
260 .uri()
261 .host()
262 .ok_or(Error::Url(tungstenite::error::UrlError::NoHostName))?;
263 let port = port(&request)?;
264
265 let socket = TcpStream::connect((domain, port))
266 .await
267 .map_err(Error::Io)?;
268 socket.set_nodelay(config.disable_nagle)?;
269 client_async_tls_with_config(request, socket, connector, config).await
270}
271
272#[inline]
273fn port(request: &Request) -> Result<u16, Error> {
274 request
275 .uri()
276 .port_u16()
277 .or_else(|| match uri_mode(request.uri()).ok()? {
278 Mode::Plain => Some(80),
279 Mode::Tls => Some(443),
280 })
281 .ok_or(Error::Url(
282 tungstenite::error::UrlError::UnsupportedUrlScheme,
283 ))
284}
285
286#[inline]
287fn domain(request: &Request) -> Result<&str, Error> {
288 request
289 .uri()
290 .host()
291 .map(|host| {
292 if host.starts_with('[') && host.ends_with(']') {
300 &host[1..host.len() - 1]
301 } else {
302 host
303 }
304 })
305 .ok_or(tungstenite::Error::Url(
306 tungstenite::error::UrlError::NoHostName,
307 ))
308}