Skip to main content

compio_tls/
maybe.rs

1use std::{
2    borrow::Cow,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf};
9use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
10
11use crate::{TlsStream, read_futures};
12
13#[derive(Debug)]
14#[allow(clippy::large_enum_variant)]
15enum MaybeTlsStreamInner<S: Splittable> {
16    Plain(Pin<Box<AsyncStream<S>>>),
17    Tls(TlsStream<S>),
18}
19
20/// Stream that can be either plain TCP or TLS-encrypted, with compatibility for
21/// [`futures_util`].
22#[derive(Debug)]
23pub struct MaybeTlsStream<S: Splittable>(MaybeTlsStreamInner<S>);
24
25impl<S: Splittable> MaybeTlsStream<S> {
26    /// Create an unencrypted stream.
27    pub fn new_plain(stream: S) -> Self {
28        Self(MaybeTlsStreamInner::Plain(Box::pin(AsyncStream::new(
29            stream,
30        ))))
31    }
32
33    /// Create an unencrypted stream from [`AsyncStream`].
34    pub fn new_plain_compat(stream: AsyncStream<S>) -> Self {
35        Self(MaybeTlsStreamInner::Plain(Box::pin(stream)))
36    }
37
38    /// Create a TLS-encrypted stream.
39    pub fn new_tls(stream: TlsStream<S>) -> Self {
40        Self(MaybeTlsStreamInner::Tls(stream))
41    }
42
43    /// Whether the stream is TLS-encrypted.
44    pub fn is_tls(&self) -> bool {
45        matches!(self.0, MaybeTlsStreamInner::Tls(_))
46    }
47}
48
49impl<S: Splittable + 'static> MaybeTlsStream<S>
50where
51    S::ReadHalf: AsyncRead + Unpin,
52    S::WriteHalf: AsyncWrite + Unpin,
53{
54    /// Returns the negotiated ALPN protocol.
55    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
56        match &self.0 {
57            MaybeTlsStreamInner::Plain(_) => None,
58            MaybeTlsStreamInner::Tls(s) => s.negotiated_alpn(),
59        }
60    }
61}
62
63impl<S: Splittable + 'static> futures_util::AsyncRead for MaybeTlsStream<S>
64where
65    S::ReadHalf: AsyncRead + Unpin,
66    S::WriteHalf: AsyncWrite + Unpin,
67{
68    fn poll_read(
69        self: Pin<&mut Self>,
70        cx: &mut Context<'_>,
71        buf: &mut [u8],
72    ) -> Poll<io::Result<usize>> {
73        match &mut self.get_mut().0 {
74            MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
75            MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
76        }
77    }
78}
79
80impl<S: Splittable + 'static> AsyncRead for MaybeTlsStream<S>
81where
82    S::ReadHalf: AsyncRead + Unpin,
83    S::WriteHalf: AsyncWrite + Unpin,
84{
85    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
86        read_futures(self, buf).await
87    }
88}
89
90impl<S: Splittable + 'static> futures_util::AsyncWrite for MaybeTlsStream<S>
91where
92    S::ReadHalf: AsyncRead + Unpin,
93    S::WriteHalf: AsyncWrite + Unpin,
94{
95    fn poll_write(
96        self: Pin<&mut Self>,
97        cx: &mut Context<'_>,
98        buf: &[u8],
99    ) -> Poll<io::Result<usize>> {
100        match &mut self.get_mut().0 {
101            MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
102            MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
103        }
104    }
105
106    fn poll_write_vectored(
107        self: Pin<&mut Self>,
108        cx: &mut Context<'_>,
109        bufs: &[io::IoSlice<'_>],
110    ) -> Poll<io::Result<usize>> {
111        match &mut self.get_mut().0 {
112            MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
113            MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
114        }
115    }
116
117    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
118        match &mut self.get_mut().0 {
119            MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_flush(cx),
120            MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_flush(cx),
121        }
122    }
123
124    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
125        match &mut self.get_mut().0 {
126            MaybeTlsStreamInner::Plain(stream) => Pin::new(stream).poll_close(cx),
127            MaybeTlsStreamInner::Tls(stream) => Pin::new(stream).poll_close(cx),
128        }
129    }
130}
131
132impl<S: Splittable + 'static> AsyncWrite for MaybeTlsStream<S>
133where
134    S::ReadHalf: AsyncRead + Unpin,
135    S::WriteHalf: AsyncWrite + Unpin,
136{
137    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
138        let slice = buf.as_init();
139        let res = futures_util::AsyncWriteExt::write(self, slice).await;
140        BufResult(res, buf)
141    }
142
143    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
144        let slices = buf.iter_slice().map(io::IoSlice::new).collect::<Vec<_>>();
145        let res = futures_util::AsyncWriteExt::write_vectored(self, &slices).await;
146        BufResult(res, buf)
147    }
148
149    async fn flush(&mut self) -> io::Result<()> {
150        futures_util::AsyncWriteExt::flush(self).await
151    }
152
153    async fn shutdown(&mut self) -> io::Result<()> {
154        futures_util::AsyncWriteExt::close(self).await
155    }
156}