1#[cfg(feature = "rustls")]
2use std::pin::Pin;
3use std::{borrow::Cow, io, mem::MaybeUninit};
4
5use compio_buf::{BufResult, IoBuf, IoBufMut};
6use compio_io::{
7 AsyncRead, AsyncWrite,
8 compat::{AsyncStream, SyncStream},
9};
10
/// Backend-specific TLS stream, chosen at compile time through cargo features.
///
/// Which variants exist depends on the enabled features. When none of
/// `native-tls`, `rustls` or `py-dynamic-openssl` is enabled, the only variant
/// is `None`, which holds an `Infallible` and therefore can never be built.
// NOTE(review): the size lint is silenced instead of boxing the variants;
// presumably to avoid an extra allocation per stream — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
enum TlsStreamInner<S> {
    /// `native-tls` stream running on top of the `SyncStream` adapter.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsStream<SyncStream<S>>),
    /// `futures-rustls` stream (client or server side) running on top of a
    /// pinned, boxed `AsyncStream` adapter.
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsStream<Pin<Box<AsyncStream<S>>>>),
    /// Dynamically loaded OpenSSL stream running on top of the `SyncStream`
    /// adapter.
    #[cfg(feature = "py-dynamic-openssl")]
    PyDynamicOpenSsl(compio_py_dynamic_openssl::ssl::SslStream<SyncStream<S>>),
    /// Placeholder that keeps the enum (and its `S` parameter) well-formed
    /// when no backend feature is enabled; it is uninhabited.
    #[cfg(not(any(feature = "native-tls", feature = "rustls", feature = "py-dynamic-openssl")))]
    None(std::convert::Infallible, std::marker::PhantomData<S>),
}
27
28impl<S> TlsStreamInner<S> {
29 pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
30 match self {
31 #[cfg(feature = "native-tls")]
32 Self::NativeTls(s) => s.negotiated_alpn().ok().flatten().map(Cow::from),
33 #[cfg(feature = "rustls")]
34 Self::Rustls(s) => s.get_ref().1.alpn_protocol().map(Cow::from),
35 #[cfg(feature = "py-dynamic-openssl")]
36 Self::PyDynamicOpenSsl(s) => crate::py_ossl::negotiated_alpn(s),
37 #[cfg(not(any(
38 feature = "native-tls",
39 feature = "rustls",
40 feature = "py-dynamic-openssl",
41 )))]
42 Self::None(f, ..) => match *f {},
43 }
44 }
45}
46
47#[derive(Debug)]
55pub struct TlsStream<S>(TlsStreamInner<S>);
56
57impl<S> TlsStream<S> {
58 pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
60 self.0.negotiated_alpn()
61 }
62}
63
#[cfg(feature = "native-tls")]
#[doc(hidden)]
impl<S> From<native_tls::TlsStream<SyncStream<S>>> for TlsStream<S> {
    /// Wraps a `native-tls` stream in the backend-agnostic [`TlsStream`].
    fn from(stream: native_tls::TlsStream<SyncStream<S>>) -> Self {
        let inner = TlsStreamInner::NativeTls(stream);
        Self(inner)
    }
}
71
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S> From<futures_rustls::client::TlsStream<Pin<Box<AsyncStream<S>>>>> for TlsStream<S> {
    /// Wraps a client-side `futures-rustls` stream in the backend-agnostic
    /// [`TlsStream`].
    fn from(stream: futures_rustls::client::TlsStream<Pin<Box<AsyncStream<S>>>>) -> Self {
        let client = futures_rustls::TlsStream::Client(stream);
        Self(TlsStreamInner::Rustls(client))
    }
}
81
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S> From<futures_rustls::server::TlsStream<Pin<Box<AsyncStream<S>>>>> for TlsStream<S> {
    /// Wraps a server-side `futures-rustls` stream in the backend-agnostic
    /// [`TlsStream`].
    fn from(stream: futures_rustls::server::TlsStream<Pin<Box<AsyncStream<S>>>>) -> Self {
        let server = futures_rustls::TlsStream::Server(stream);
        Self(TlsStreamInner::Rustls(server))
    }
}
91
#[cfg(feature = "py-dynamic-openssl")]
#[doc(hidden)]
impl<S> From<compio_py_dynamic_openssl::ssl::SslStream<SyncStream<S>>> for TlsStream<S> {
    /// Wraps a dynamically loaded OpenSSL stream in the backend-agnostic
    /// [`TlsStream`].
    fn from(stream: compio_py_dynamic_openssl::ssl::SslStream<SyncStream<S>>) -> Self {
        let inner = TlsStreamInner::PyDynamicOpenSsl(stream);
        Self(inner)
    }
}
99
/// Runs the synchronous operation `f` on a `native-tls` stream, performing the
/// asynchronous transport I/O it needs in between attempts.
///
/// `f` is retried for as long as it fails with [`io::ErrorKind::WouldBlock`].
/// After such a failure the transport's pending write data is flushed if there
/// is any; otherwise its read buffer is filled. Once `f` succeeds, any write
/// data still pending is flushed before the value is returned.
///
/// # Errors
///
/// Returns the first error from `f` that is not `WouldBlock`, or any error
/// raised while flushing or filling the transport buffers.
#[cfg(feature = "native-tls")]
#[inline]
async fn drive<S, F, T>(s: &mut native_tls::TlsStream<SyncStream<S>>, mut f: F) -> io::Result<T>
where
    S: AsyncRead + AsyncWrite,
    F: FnMut(&mut native_tls::TlsStream<SyncStream<S>>) -> io::Result<T>,
{
    loop {
        let attempt = f(s);
        let transport = s.get_mut();
        match attempt {
            Ok(value) => {
                // The operation may have queued outgoing bytes; push them out
                // before reporting success.
                if transport.has_pending_write() {
                    transport.flush_write_buf().await?;
                }
                return Ok(value);
            }
            Err(err) if err.kind() != io::ErrorKind::WouldBlock => return Err(err),
            Err(_) => {
                // `WouldBlock`: drain queued output first, otherwise pull in
                // more input, then try again.
                if transport.has_pending_write() {
                    transport.flush_write_buf().await?;
                } else {
                    transport.fill_read_buf().await?;
                }
            }
        }
    }
}
128
129impl<S: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for TlsStream<S>
130where
131 for<'a> &'a S: AsyncRead + AsyncWrite,
132{
133 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
134 let slice = buf.as_uninit();
135 slice.fill(MaybeUninit::new(0));
136 let slice =
138 unsafe { std::slice::from_raw_parts_mut::<u8>(slice.as_mut_ptr().cast(), slice.len()) };
139 match &mut self.0 {
140 #[cfg(feature = "native-tls")]
141 TlsStreamInner::NativeTls(s) => match drive(s, |s| io::Read::read(s, slice)).await {
142 Ok(res) => {
143 unsafe { buf.advance_to(res) };
144 BufResult(Ok(res), buf)
145 }
146 res => BufResult(res, buf),
147 },
148 #[cfg(feature = "rustls")]
149 TlsStreamInner::Rustls(s) => {
150 let res = futures_util::AsyncReadExt::read(s, slice).await;
151 let res = match res {
152 Ok(len) => {
153 unsafe { buf.advance_to(len) };
154 Ok(len)
155 }
156 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(0),
159 _ => res,
160 };
161 BufResult(res, buf)
162 }
163 #[cfg(feature = "py-dynamic-openssl")]
164 TlsStreamInner::PyDynamicOpenSsl(s) => match crate::py_ossl::read(s, slice).await {
165 Ok(res) => {
166 unsafe { buf.advance_to(res) };
167 BufResult(Ok(res), buf)
168 }
169 Err(e) => BufResult(Err(e), buf),
170 },
171 #[cfg(not(any(
172 feature = "native-tls",
173 feature = "rustls",
174 feature = "py-dynamic-openssl",
175 )))]
176 TlsStreamInner::None(f, ..) => match *f {},
177 }
178 }
179}
180
/// Flushes a `native-tls` stream and then the transport's write buffer.
///
/// The synchronous `flush` is retried while it fails with
/// [`io::ErrorKind::WouldBlock`], flushing the transport's write buffer between
/// attempts. After it succeeds the write buffer is flushed one final time.
///
/// # Errors
///
/// Returns the first non-`WouldBlock` error from the TLS flush, or any error
/// from flushing the transport's write buffer.
#[cfg(feature = "native-tls")]
async fn flush_impl(s: &mut native_tls::TlsStream<SyncStream<impl AsyncWrite>>) -> io::Result<()> {
    while let Err(err) = io::Write::flush(s) {
        if err.kind() != io::ErrorKind::WouldBlock {
            return Err(err);
        }
        s.get_mut().flush_write_buf().await?;
    }
    s.get_mut().flush_write_buf().await.map(|_| ())
}
195
196impl<S: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for TlsStream<S>
197where
198 for<'a> &'a S: AsyncRead + AsyncWrite,
199{
200 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
201 let slice = buf.as_init();
202 match &mut self.0 {
203 #[cfg(feature = "native-tls")]
204 TlsStreamInner::NativeTls(s) => {
205 let res = drive(s, |s| io::Write::write(s, slice)).await;
206 BufResult(res, buf)
207 }
208 #[cfg(feature = "rustls")]
209 TlsStreamInner::Rustls(s) => {
210 let res = futures_util::AsyncWriteExt::write(s, slice).await;
211 BufResult(res, buf)
212 }
213 #[cfg(feature = "py-dynamic-openssl")]
214 TlsStreamInner::PyDynamicOpenSsl(s) => {
215 let res = crate::py_ossl::write(s, slice).await;
216 BufResult(res, buf)
217 }
218 #[cfg(not(any(
219 feature = "native-tls",
220 feature = "rustls",
221 feature = "py-dynamic-openssl",
222 )))]
223 TlsStreamInner::None(f, ..) => match *f {},
224 }
225 }
226
227 async fn flush(&mut self) -> io::Result<()> {
228 match &mut self.0 {
229 #[cfg(feature = "native-tls")]
230 TlsStreamInner::NativeTls(s) => flush_impl(s).await,
231 #[cfg(feature = "rustls")]
232 TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::flush(s).await,
233 #[cfg(feature = "py-dynamic-openssl")]
234 TlsStreamInner::PyDynamicOpenSsl(s) => s.get_mut().flush_write_buf().await.map(|_| ()),
235 #[cfg(not(any(
236 feature = "native-tls",
237 feature = "rustls",
238 feature = "py-dynamic-openssl",
239 )))]
240 TlsStreamInner::None(f, ..) => match *f {},
241 }
242 }
243
244 async fn shutdown(&mut self) -> io::Result<()> {
245 self.flush().await?;
246 match &mut self.0 {
247 #[cfg(feature = "native-tls")]
248 TlsStreamInner::NativeTls(s) => {
249 drive(s, |s| s.shutdown()).await?;
257 s.get_mut().get_mut().shutdown().await
258 }
259 #[cfg(feature = "rustls")]
260 TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::close(s).await,
261 #[cfg(feature = "py-dynamic-openssl")]
262 TlsStreamInner::PyDynamicOpenSsl(s) => crate::py_ossl::shutdown(s).await,
263 #[cfg(not(any(
264 feature = "native-tls",
265 feature = "rustls",
266 feature = "py-dynamic-openssl",
267 )))]
268 TlsStreamInner::None(f, ..) => match *f {},
269 }
270 }
271}