1use std::{
2 borrow::Cow,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf};
9use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
10
11use crate::{TlsStream, read_futures};
12
13#[derive(Debug)]
14#[allow(clippy::large_enum_variant)]
15enum MaybeTlsStreamInner<S: Splittable> {
16 Plain(Pin<Box<AsyncStream<S>>>),
17 Tls(TlsStream<S>),
18}
19
20#[derive(Debug)]
23pub struct MaybeTlsStream<S: Splittable>(MaybeTlsStreamInner<S>);
24
25impl<S: Splittable> MaybeTlsStream<S> {
26 pub fn new_plain(stream: S) -> Self {
28 Self(MaybeTlsStreamInner::Plain(Box::pin(AsyncStream::new(
29 stream,
30 ))))
31 }
32
33 pub fn new_plain_compat(stream: AsyncStream<S>) -> Self {
35 Self(MaybeTlsStreamInner::Plain(Box::pin(stream)))
36 }
37
38 pub fn new_tls(stream: TlsStream<S>) -> Self {
40 Self(MaybeTlsStreamInner::Tls(stream))
41 }
42
43 pub fn is_tls(&self) -> bool {
45 matches!(self.0, MaybeTlsStreamInner::Tls(_))
46 }
47}
48
49impl<S: Splittable + 'static> MaybeTlsStream<S>
50where
51 S::ReadHalf: AsyncRead + Unpin,
52 S::WriteHalf: AsyncWrite + Unpin,
53{
54 pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
56 match &self.0 {
57 MaybeTlsStreamInner::Plain(_) => None,
58 MaybeTlsStreamInner::Tls(s) => s.negotiated_alpn(),
59 }
60 }
61}
62
63impl<S: Splittable + 'static> futures_util::AsyncRead for MaybeTlsStream<S>
64where
65 S::ReadHalf: AsyncRead + Unpin,
66 S::WriteHalf: AsyncWrite + Unpin,
67{
68 fn poll_read(
69 self: Pin<&mut Self>,
70 cx: &mut Context<'_>,
71 buf: &mut [u8],
72 ) -> Poll<io::Result<usize>> {
73 match &mut self.get_mut().0 {
74 MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
75 MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
76 }
77 }
78}
79
80impl<S: Splittable + 'static> AsyncRead for MaybeTlsStream<S>
81where
82 S::ReadHalf: AsyncRead + Unpin,
83 S::WriteHalf: AsyncWrite + Unpin,
84{
85 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
86 read_futures(self, buf).await
87 }
88}
89
90impl<S: Splittable + 'static> futures_util::AsyncWrite for MaybeTlsStream<S>
91where
92 S::ReadHalf: AsyncRead + Unpin,
93 S::WriteHalf: AsyncWrite + Unpin,
94{
95 fn poll_write(
96 self: Pin<&mut Self>,
97 cx: &mut Context<'_>,
98 buf: &[u8],
99 ) -> Poll<io::Result<usize>> {
100 match &mut self.get_mut().0 {
101 MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
102 MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
103 }
104 }
105
106 fn poll_write_vectored(
107 self: Pin<&mut Self>,
108 cx: &mut Context<'_>,
109 bufs: &[io::IoSlice<'_>],
110 ) -> Poll<io::Result<usize>> {
111 match &mut self.get_mut().0 {
112 MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
113 MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
114 }
115 }
116
117 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
118 match &mut self.get_mut().0 {
119 MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_flush(cx),
120 MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_flush(cx),
121 }
122 }
123
124 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
125 match &mut self.get_mut().0 {
126 MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_close(cx),
127 MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_close(cx),
128 }
129 }
130}
131
132impl<S: Splittable + 'static> AsyncWrite for MaybeTlsStream<S>
133where
134 S::ReadHalf: AsyncRead + Unpin,
135 S::WriteHalf: AsyncWrite + Unpin,
136{
137 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
138 let slice = buf.as_init();
139 let res = futures_util::AsyncWriteExt::write(self, slice).await;
140 BufResult(res, buf)
141 }
142
143 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
144 let slices = buf.iter_slice().map(io::IoSlice::new).collect::<Vec<_>>();
145 let res = futures_util::AsyncWriteExt::write_vectored(self, &slices).await;
146 BufResult(res, buf)
147 }
148
149 async fn flush(&mut self) -> io::Result<()> {
150 futures_util::AsyncWriteExt::flush(self).await
151 }
152
153 async fn shutdown(&mut self) -> io::Result<()> {
154 futures_util::AsyncWriteExt::close(self).await
155 }
156}