Skip to main content

compio_tls/
rtls.rs

1use std::{
2    io,
3    pin::Pin,
4    sync::Arc,
5    task::{Context, Poll},
6};
7
8use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream};
9use futures_util::FutureExt;
10use rustls::{
11    ServerConfig, ServerConnection,
12    server::{Acceptor, ClientHello},
13};
14
15use crate::TlsStream;
16
17/// A lazy TLS acceptor that performs the initial handshake and allows access to
18/// the [`ClientHello`] message before completing the handshake.
19pub struct LazyConfigAcceptor<S>(futures_rustls::LazyConfigAcceptor<Pin<Box<AsyncStream<S>>>>);
20
21impl<S: AsyncRead + AsyncWrite + Unpin + 'static> LazyConfigAcceptor<S>
22where
23    for<'a> &'a S: AsyncRead + AsyncWrite,
24{
25    /// Create a new [`LazyConfigAcceptor`] with the given acceptor and stream.
26    pub fn new(acceptor: Acceptor, s: S) -> Self {
27        Self(futures_rustls::LazyConfigAcceptor::new(
28            acceptor,
29            Box::pin(AsyncStream::new(s)),
30        ))
31    }
32}
33
34impl<S: AsyncRead + AsyncWrite + Unpin + 'static> Future for LazyConfigAcceptor<S>
35where
36    for<'a> &'a S: AsyncRead + AsyncWrite,
37{
38    type Output = Result<StartHandshake<S>, io::Error>;
39
40    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
41        self.0.poll_unpin(cx).map_ok(StartHandshake)
42    }
43}
44
45/// A TLS acceptor that has completed the initial handshake and allows access to
46/// the [`ClientHello`] message.
47pub struct StartHandshake<S>(futures_rustls::StartHandshake<Pin<Box<AsyncStream<S>>>>);
48
49impl<S: AsyncRead + AsyncWrite + Unpin + 'static> StartHandshake<S>
50where
51    for<'a> &'a S: AsyncRead + AsyncWrite,
52{
53    /// Get the [`ClientHello`] message from the initial handshake.
54    pub fn client_hello(&self) -> ClientHello<'_> {
55        self.0.client_hello()
56    }
57
58    /// Complete the TLS handshake and return a [`TlsStream`] if successful.
59    pub fn into_stream(
60        self,
61        config: Arc<ServerConfig>,
62    ) -> impl Future<Output = io::Result<TlsStream<S>>> {
63        self.into_stream_with(config, |_| ())
64    }
65
66    /// Complete the TLS handshake and return a [`TlsStream`] if successful.
67    pub fn into_stream_with<F>(
68        self,
69        config: Arc<ServerConfig>,
70        f: F,
71    ) -> impl Future<Output = io::Result<TlsStream<S>>>
72    where
73        F: FnOnce(&mut ServerConnection),
74    {
75        self.0
76            .into_stream_with(config, f)
77            .map(|res| res.map(TlsStream::from))
78    }
79}