Skip to main content

compio_tls/
maybe.rs

1use std::{borrow::Cow, io};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut};
4use compio_io::{AsyncRead, AsyncWrite};
5
6use crate::TlsStream;
7
8#[derive(Debug)]
9#[allow(clippy::large_enum_variant)]
10enum MaybeTlsStreamInner<S> {
11    /// Plain, unencrypted stream
12    Plain(S),
13    /// TLS-encrypted stream
14    Tls(TlsStream<S>),
15}
16
17/// Stream that can be either plain TCP or TLS-encrypted
18#[derive(Debug)]
19pub struct MaybeTlsStream<S>(MaybeTlsStreamInner<S>);
20
21impl<S> MaybeTlsStream<S> {
22    /// Create an unencrypted stream.
23    pub fn new_plain(stream: S) -> Self {
24        Self(MaybeTlsStreamInner::Plain(stream))
25    }
26
27    /// Create a TLS-encrypted stream.
28    pub fn new_tls(stream: TlsStream<S>) -> Self {
29        Self(MaybeTlsStreamInner::Tls(stream))
30    }
31
32    /// Whether the stream is TLS-encrypted.
33    pub fn is_tls(&self) -> bool {
34        matches!(self.0, MaybeTlsStreamInner::Tls(_))
35    }
36
37    /// Returns the negotiated ALPN protocol.
38    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
39        match &self.0 {
40            MaybeTlsStreamInner::Plain(_) => None,
41            MaybeTlsStreamInner::Tls(s) => s.negotiated_alpn(),
42        }
43    }
44}
45
46impl<S> AsyncRead for MaybeTlsStream<S>
47where
48    S: AsyncRead + AsyncWrite + Unpin + 'static,
49    for<'a> &'a S: AsyncRead + AsyncWrite,
50{
51    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
52        match &mut self.0 {
53            MaybeTlsStreamInner::Plain(stream) => stream.read(buf).await,
54            MaybeTlsStreamInner::Tls(stream) => stream.read(buf).await,
55        }
56    }
57}
58
59impl<S> AsyncWrite for MaybeTlsStream<S>
60where
61    S: AsyncRead + AsyncWrite + Unpin + 'static,
62    for<'a> &'a S: AsyncRead + AsyncWrite,
63{
64    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
65        match &mut self.0 {
66            MaybeTlsStreamInner::Plain(stream) => stream.write(buf).await,
67            MaybeTlsStreamInner::Tls(stream) => stream.write(buf).await,
68        }
69    }
70
71    async fn flush(&mut self) -> io::Result<()> {
72        match &mut self.0 {
73            MaybeTlsStreamInner::Plain(stream) => stream.flush().await,
74            MaybeTlsStreamInner::Tls(stream) => stream.flush().await,
75        }
76    }
77
78    async fn shutdown(&mut self) -> io::Result<()> {
79        match &mut self.0 {
80            MaybeTlsStreamInner::Plain(stream) => stream.shutdown().await,
81            MaybeTlsStreamInner::Tls(stream) => stream.shutdown().await,
82        }
83    }
84}