1use std::{fmt::Debug, io};
2
3use compio_io::{
4 AsyncRead, AsyncWrite,
5 compat::{AsyncStream, SyncStream},
6};
7
8use crate::TlsStream;
9
/// Backend-specific TLS client state wrapped by `TlsConnector`.
///
/// There is one variant per enabled TLS cargo feature. When none of those
/// features is enabled, the only variant holds `Infallible`, which makes the
/// type uninhabited while keeping every `match` on it exhaustive.
#[derive(Clone)]
enum TlsConnectorInner {
    /// Connector from the `native_tls` crate.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsConnector),
    /// Connector from the `futures_rustls` crate.
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsConnector),
    /// SSL context from `compio_py_dynamic_openssl`.
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(compio_py_dynamic_openssl::SSLContext),
    /// Placeholder that can never be constructed when no backend is enabled.
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl"
    )))]
    None(std::convert::Infallible),
}

/// Formats only the backend name (e.g. `Rustls`). The wrapped backend value is
/// never inspected, so it does not have to implement `Debug` itself.
impl Debug for TlsConnectorInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less `debug_tuple(name).finish()` writes exactly `name`, so
        // writing the name directly yields identical output.
        match self {
            #[cfg(feature = "native-tls")]
            Self::NativeTls(_) => f.write_str("NativeTls"),
            #[cfg(feature = "rustls")]
            Self::Rustls(_) => f.write_str("Rustls"),
            #[cfg(feature = "py-dynamic-openssl")]
            Self::PyDynamicOpenSsl(_) => f.write_str("PyDynamicOpenSsl"),
            #[cfg(not(any(
                feature = "native-tls",
                feature = "rustls",
                feature = "py-dynamic-openssl"
            )))]
            Self::None(never) => match *never {},
        }
    }
}

/// Client-side TLS connector wrapping one of the enabled TLS backends.
///
/// Build it with `From` from a backend-specific connector or configuration
/// (for example `native_tls::TlsConnector` or `Arc<rustls::ClientConfig>`,
/// each available only when its cargo feature is enabled). Then call
/// `connect` to run the client handshake over a compio stream.
#[derive(Debug, Clone)]
pub struct TlsConnector(TlsConnectorInner);
49
/// Wraps a `native_tls::TlsConnector`, selecting the native-tls backend.
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsConnector> for TlsConnector {
    fn from(connector: native_tls::TlsConnector) -> Self {
        let inner = TlsConnectorInner::NativeTls(connector);
        TlsConnector(inner)
    }
}

/// Builds a rustls-backed connector from a shared client configuration.
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ClientConfig>> for TlsConnector {
    fn from(config: std::sync::Arc<rustls::ClientConfig>) -> Self {
        let connector = futures_rustls::TlsConnector::from(config);
        TlsConnector(TlsConnectorInner::Rustls(connector))
    }
}

/// Wraps a `compio_py_dynamic_openssl::SSLContext` (hidden from rustdoc).
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl From<compio_py_dynamic_openssl::SSLContext> for TlsConnector {
    fn from(context: compio_py_dynamic_openssl::SSLContext) -> Self {
        let inner = TlsConnectorInner::PyDynamicOpenSsl(context);
        TlsConnector(inner)
    }
}
71
72impl TlsConnector {
73 pub async fn connect<S: AsyncRead + AsyncWrite + Unpin + 'static>(
86 &self,
87 domain: &str,
88 stream: S,
89 ) -> io::Result<TlsStream<S>>
90 where
91 for<'a> &'a S: AsyncRead + AsyncWrite,
92 {
93 match &self.0 {
94 #[cfg(feature = "native-tls")]
95 TlsConnectorInner::NativeTls(c) => {
96 handshake_native_tls(c.connect(domain, SyncStream::new(stream))).await
97 }
98 #[cfg(feature = "rustls")]
99 TlsConnectorInner::Rustls(c) => {
100 let client = c
101 .connect(
102 domain.to_string().try_into().map_err(io::Error::other)?,
103 Box::pin(AsyncStream::new(stream)),
104 )
105 .await?;
106 Ok(TlsStream::from(client))
107 }
108 #[cfg(feature = "py-dynamic-openssl")]
109 TlsConnectorInner::PyDynamicOpenSsl(c) => {
110 crate::py_ossl::handshake(c.connect(domain, SyncStream::new(stream))).await
111 }
112 #[cfg(not(any(
113 feature = "native-tls",
114 feature = "rustls",
115 feature = "py-dynamic-openssl"
116 )))]
117 TlsConnectorInner::None(f) => match *f {},
118 }
119 }
120}
121
/// Backend-specific TLS server state wrapped by `TlsAcceptor`.
///
/// There is one variant per enabled TLS cargo feature. When none of those
/// features is enabled, the only variant holds `Infallible`, which makes the
/// type uninhabited while keeping every `match` on it exhaustive.
#[derive(Clone)]
enum TlsAcceptorInner {
    /// Acceptor from the `native_tls` crate.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsAcceptor),
    /// Acceptor from the `futures_rustls` crate.
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsAcceptor),
    /// SSL context from `compio_py_dynamic_openssl`.
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(compio_py_dynamic_openssl::SSLContext),
    /// Placeholder that can never be constructed when no backend is enabled.
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl"
    )))]
    None(std::convert::Infallible),
}

/// Formats only the backend name, matching the connector side's `Debug`
/// output. The wrapped backend value is not inspected, so it does not have to
/// implement `Debug` for `TlsAcceptor` to derive it.
impl Debug for TlsAcceptorInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "native-tls")]
            Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
            #[cfg(feature = "rustls")]
            Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
            #[cfg(feature = "py-dynamic-openssl")]
            Self::PyDynamicOpenSsl(_) => f.debug_tuple("PyDynamicOpenSsl").finish(),
            #[cfg(not(any(
                feature = "native-tls",
                feature = "rustls",
                feature = "py-dynamic-openssl"
            )))]
            Self::None(never) => match *never {},
        }
    }
}

/// Server-side TLS acceptor wrapping one of the enabled TLS backends.
///
/// Build it with `From` from a backend-specific acceptor or configuration
/// (for example `native_tls::TlsAcceptor` or `Arc<rustls::ServerConfig>`,
/// each available only when its cargo feature is enabled). Then call
/// `accept` to run the server handshake on an incoming stream.
///
/// Like `TlsConnector`, it implements `Debug`, which shows only the backend
/// name.
#[derive(Debug, Clone)]
pub struct TlsAcceptor(TlsAcceptorInner);
145
/// Wraps a `native_tls::TlsAcceptor`, selecting the native-tls backend.
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
    fn from(value: native_tls::TlsAcceptor) -> Self {
        Self(TlsAcceptorInner::NativeTls(value))
    }
}

/// Builds a rustls-backed acceptor from a shared server configuration.
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ServerConfig>> for TlsAcceptor {
    fn from(value: std::sync::Arc<rustls::ServerConfig>) -> Self {
        Self(TlsAcceptorInner::Rustls(value.into()))
    }
}

/// Wraps a `compio_py_dynamic_openssl::SSLContext`.
///
/// This impl is hidden from rustdoc for consistency with the matching
/// `TlsConnector` conversion, which is also `#[doc(hidden)]`.
/// NOTE(review): presumably an integration hook rather than public API; confirm.
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl From<compio_py_dynamic_openssl::SSLContext> for TlsAcceptor {
    fn from(value: compio_py_dynamic_openssl::SSLContext) -> Self {
        Self(TlsAcceptorInner::PyDynamicOpenSsl(value))
    }
}
166
167impl TlsAcceptor {
168 pub async fn accept<S: AsyncRead + AsyncWrite + Unpin + 'static>(
179 &self,
180 stream: S,
181 ) -> io::Result<TlsStream<S>>
182 where
183 for<'a> &'a S: AsyncRead + AsyncWrite,
184 {
185 match &self.0 {
186 #[cfg(feature = "native-tls")]
187 TlsAcceptorInner::NativeTls(c) => {
188 handshake_native_tls(c.accept(SyncStream::new(stream))).await
189 }
190 #[cfg(feature = "rustls")]
191 TlsAcceptorInner::Rustls(c) => {
192 let server = c.accept(Box::pin(AsyncStream::new(stream))).await?;
193 Ok(TlsStream::from(server))
194 }
195 #[cfg(feature = "py-dynamic-openssl")]
196 TlsAcceptorInner::PyDynamicOpenSsl(a) => {
197 crate::py_ossl::handshake(a.accept(SyncStream::new(stream))).await
198 }
199 #[cfg(not(any(
200 feature = "native-tls",
201 feature = "rustls",
202 feature = "py-dynamic-openssl"
203 )))]
204 TlsAcceptorInner::None(f) => match *f {},
205 }
206 }
207}
208
/// Drives a native-tls handshake to completion over a compio stream.
///
/// `res` is the result of the initial synchronous `connect`/`accept` call made
/// on a `SyncStream`. While native-tls reports `WouldBlock`, each step does
/// one of two things before retrying the handshake:
/// - if there is a pending write, the buffered output is flushed to the
///   underlying stream;
/// - otherwise, more input is read into the buffer.
///
/// Once the handshake completes, any still-pending output is flushed before
/// the stream is wrapped in a `TlsStream`.
///
/// # Errors
///
/// A `HandshakeError::Failure` is converted with `io::Error::other`. I/O errors
/// from flushing or filling the buffer are returned unchanged.
#[cfg(feature = "native-tls")]
async fn handshake_native_tls<S: AsyncRead + AsyncWrite>(
    mut res: Result<
        native_tls::TlsStream<SyncStream<S>>,
        native_tls::HandshakeError<SyncStream<S>>,
    >,
) -> io::Result<TlsStream<S>> {
    use native_tls::HandshakeError;

    let mut established = loop {
        let mut pending = match res {
            Ok(tls) => break tls,
            Err(HandshakeError::Failure(err)) => return Err(io::Error::other(err)),
            Err(HandshakeError::WouldBlock(mid_stream)) => mid_stream,
        };
        let buffered = pending.get_mut();
        if buffered.has_pending_write() {
            buffered.flush_write_buf().await?;
        } else {
            buffered.fill_read_buf().await?;
        }
        res = pending.handshake();
    };

    // The final handshake step may have left output in the write buffer.
    let buffered = established.get_mut();
    if buffered.has_pending_write() {
        buffered.flush_write_buf().await?;
    }
    Ok(TlsStream::from(established))
}