1use std::{fmt::Debug, io};
2
3use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
4
5use crate::TlsStream;
6
/// Backend-specific connector stored inside `TlsConnector`.
///
/// Which variants exist is decided at compile time by the enabled cargo
/// features.
#[derive(Clone)]
enum TlsConnectorInner {
    /// Connector from this crate's `native` module (`native-tls` feature).
    #[cfg(feature = "native-tls")]
    NativeTls(crate::native::TlsConnector),
    /// Connector from `futures_rustls` (`rustls` feature).
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsConnector),
    /// Connector from this crate's `py_ossl` module (`py-dynamic-openssl`
    /// feature).
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(crate::py_ossl::TlsConnector),
    /// Placeholder compiled in only when no backend feature is enabled. Its
    /// payload is `Infallible`, so a value of this variant cannot be built.
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl"
    )))]
    None(std::convert::Infallible),
}
22
23impl Debug for TlsConnectorInner {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 #[cfg(feature = "native-tls")]
27 Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
28 #[cfg(feature = "rustls")]
29 Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
30 #[cfg(feature = "py-dynamic-openssl")]
31 Self::PyDynamicOpenSsl(_) => f.debug_tuple("PyDynamicOpenSsl").finish(),
32 #[cfg(not(any(
33 feature = "native-tls",
34 feature = "rustls",
35 feature = "py-dynamic-openssl"
36 )))]
37 Self::None(f) => match *f {},
38 }
39 }
40}
41
/// Client-side TLS connector wrapping one of the feature-gated backends.
///
/// Values are created through the `From` conversions defined for the enabled
/// backends; the handshake is performed with `connect` or `connect_compat`.
#[derive(Debug, Clone)]
pub struct TlsConnector(TlsConnectorInner);
46
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsConnector> for TlsConnector {
    /// Converts a `native_tls` connector and stores it as the `NativeTls`
    /// backend.
    fn from(value: native_tls::TlsConnector) -> Self {
        let inner = TlsConnectorInner::NativeTls(value.into());
        Self(inner)
    }
}
53
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ClientConfig>> for TlsConnector {
    /// Converts a shared rustls client configuration and stores it as the
    /// `Rustls` backend.
    fn from(value: std::sync::Arc<rustls::ClientConfig>) -> Self {
        let inner = TlsConnectorInner::Rustls(value.into());
        Self(inner)
    }
}
60
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl From<compio_py_dynamic_openssl::SSLContext> for TlsConnector {
    /// Converts an `SSLContext` and stores it as the `PyDynamicOpenSsl`
    /// backend.
    fn from(value: compio_py_dynamic_openssl::SSLContext) -> Self {
        let inner = TlsConnectorInner::PyDynamicOpenSsl(value.into());
        Self(inner)
    }
}
68
69impl TlsConnector {
70 pub async fn connect<S: Splittable + 'static>(
83 &self,
84 domain: &str,
85 stream: S,
86 ) -> io::Result<TlsStream<S>>
87 where
88 S::ReadHalf: AsyncRead + Unpin,
89 S::WriteHalf: AsyncWrite + Unpin,
90 {
91 self.connect_compat(domain, AsyncStream::new(stream)).await
92 }
93
94 pub async fn connect_compat<S: Splittable + 'static>(
97 &self,
98 domain: &str,
99 stream: AsyncStream<S>,
100 ) -> io::Result<TlsStream<S>>
101 where
102 S::ReadHalf: AsyncRead + Unpin,
103 S::WriteHalf: AsyncWrite + Unpin,
104 {
105 match &self.0 {
106 #[cfg(feature = "native-tls")]
107 TlsConnectorInner::NativeTls(c) => {
108 let client = c.connect(domain, Box::pin(stream)).await?;
109 Ok(TlsStream::from(client))
110 }
111 #[cfg(feature = "rustls")]
112 TlsConnectorInner::Rustls(c) => {
113 let client = c
114 .connect(
115 domain.to_string().try_into().map_err(io::Error::other)?,
116 Box::pin(stream),
117 )
118 .await?;
119 Ok(TlsStream::from(client))
120 }
121 #[cfg(feature = "py-dynamic-openssl")]
122 TlsConnectorInner::PyDynamicOpenSsl(c) => {
123 let client = c.connect(domain, Box::pin(stream)).await?;
124 Ok(TlsStream::from(client))
125 }
126 #[cfg(not(any(
127 feature = "native-tls",
128 feature = "rustls",
129 feature = "py-dynamic-openssl"
130 )))]
131 TlsConnectorInner::None(f) => match *f {},
132 }
133 }
134}
135
/// Backend-specific acceptor stored inside `TlsAcceptor`.
///
/// Which variants exist is decided at compile time by the enabled cargo
/// features.
#[derive(Clone)]
enum TlsAcceptorInner {
    /// Acceptor from this crate's `native` module (`native-tls` feature).
    #[cfg(feature = "native-tls")]
    NativeTls(crate::native::TlsAcceptor),
    /// Acceptor from `futures_rustls` (`rustls` feature).
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsAcceptor),
    /// Acceptor from this crate's `py_ossl` module (`py-dynamic-openssl`
    /// feature).
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(crate::py_ossl::TlsAcceptor),
    /// Placeholder compiled in only when no backend feature is enabled. Its
    /// payload is `Infallible`, so a value of this variant cannot be built.
    #[cfg(not(any(
        feature = "native-tls",
        feature = "rustls",
        feature = "py-dynamic-openssl"
    )))]
    None(std::convert::Infallible),
}

impl Debug for TlsAcceptorInner {
    /// Writes only the name of the active variant, in the same way as the
    /// `Debug` impl of `TlsConnectorInner`; the wrapped acceptor itself is
    /// not formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "native-tls")]
            Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
            #[cfg(feature = "rustls")]
            Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
            #[cfg(feature = "py-dynamic-openssl")]
            Self::PyDynamicOpenSsl(_) => f.debug_tuple("PyDynamicOpenSsl").finish(),
            // The payload is `Infallible`, so this arm can never be reached.
            #[cfg(not(any(
                feature = "native-tls",
                feature = "rustls",
                feature = "py-dynamic-openssl"
            )))]
            Self::None(f) => match *f {},
        }
    }
}

/// Server-side TLS acceptor wrapping one of the feature-gated backends.
///
/// Values are created through the `From` conversions defined for the enabled
/// backends; the handshake is performed with `accept` or `accept_compat`.
/// Implements `Debug` so it can be embedded in `#[derive(Debug)]` types, just
/// like `TlsConnector`.
#[derive(Debug, Clone)]
pub struct TlsAcceptor(TlsAcceptorInner);
159
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
    /// Converts a `native_tls` acceptor and stores it as the `NativeTls`
    /// backend.
    fn from(value: native_tls::TlsAcceptor) -> Self {
        let inner = TlsAcceptorInner::NativeTls(value.into());
        Self(inner)
    }
}
166
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ServerConfig>> for TlsAcceptor {
    /// Converts a shared rustls server configuration and stores it as the
    /// `Rustls` backend.
    fn from(value: std::sync::Arc<rustls::ServerConfig>) -> Self {
        let inner = TlsAcceptorInner::Rustls(value.into());
        Self(inner)
    }
}
173
// Hidden from rustdoc, matching the `SSLContext` conversion for
// `TlsConnector`, so both py-dynamic-openssl conversions are treated alike.
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl From<compio_py_dynamic_openssl::SSLContext> for TlsAcceptor {
    /// Converts an `SSLContext` and stores it as the `PyDynamicOpenSsl`
    /// backend.
    fn from(value: compio_py_dynamic_openssl::SSLContext) -> Self {
        Self(TlsAcceptorInner::PyDynamicOpenSsl(value.into()))
    }
}
180
181impl TlsAcceptor {
182 pub async fn accept<S: Splittable + 'static>(&self, stream: S) -> io::Result<TlsStream<S>>
193 where
194 S::ReadHalf: AsyncRead + Unpin,
195 S::WriteHalf: AsyncWrite + Unpin,
196 {
197 self.accept_compat(AsyncStream::new(stream)).await
198 }
199
200 pub async fn accept_compat<S: Splittable + 'static>(
203 &self,
204 stream: AsyncStream<S>,
205 ) -> io::Result<TlsStream<S>>
206 where
207 S::ReadHalf: AsyncRead + Unpin,
208 S::WriteHalf: AsyncWrite + Unpin,
209 {
210 match &self.0 {
211 #[cfg(feature = "native-tls")]
212 TlsAcceptorInner::NativeTls(c) => {
213 let server = c.accept(Box::pin(stream)).await?;
214 Ok(TlsStream::from(server))
215 }
216 #[cfg(feature = "rustls")]
217 TlsAcceptorInner::Rustls(c) => {
218 let server = c.accept(Box::pin(stream)).await?;
219 Ok(TlsStream::from(server))
220 }
221 #[cfg(feature = "py-dynamic-openssl")]
222 TlsAcceptorInner::PyDynamicOpenSsl(a) => {
223 let server = a.accept(Box::pin(stream)).await?;
224 Ok(TlsStream::from(server))
225 }
226 #[cfg(not(any(
227 feature = "native-tls",
228 feature = "rustls",
229 feature = "py-dynamic-openssl"
230 )))]
231 TlsAcceptorInner::None(f) => match *f {},
232 }
233 }
234}