Skip to main content

compio_tls/
rtls.rs

1use std::{
2    io,
3    pin::Pin,
4    sync::Arc,
5    task::{Context, Poll},
6};
7
8use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
9use futures_util::FutureExt;
10use rustls::{
11    ServerConfig, ServerConnection,
12    server::{Acceptor, ClientHello},
13};
14
15use crate::TlsStream;
16
17/// A lazy TLS acceptor that performs the initial handshake and allows access to
18/// the [`ClientHello`] message before completing the handshake.
19pub struct LazyConfigAcceptor<S: Splittable>(
20    futures_rustls::LazyConfigAcceptor<Pin<Box<AsyncStream<S>>>>,
21);
22
23impl<S: Splittable + 'static> LazyConfigAcceptor<S>
24where
25    S::ReadHalf: AsyncRead + Unpin,
26    S::WriteHalf: AsyncWrite + Unpin,
27{
28    /// Create a new [`LazyConfigAcceptor`] with the given acceptor and stream.
29    pub fn new(acceptor: Acceptor, s: S) -> Self {
30        Self(futures_rustls::LazyConfigAcceptor::new(
31            acceptor,
32            Box::pin(AsyncStream::new(s)),
33        ))
34    }
35}
36
37impl<S: Splittable + 'static> Future for LazyConfigAcceptor<S>
38where
39    S::ReadHalf: AsyncRead + Unpin,
40    S::WriteHalf: AsyncWrite + Unpin,
41{
42    type Output = Result<StartHandshake<S>, io::Error>;
43
44    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45        self.0.poll_unpin(cx).map_ok(StartHandshake)
46    }
47}
48
49/// A TLS acceptor that has completed the initial handshake and allows access to
50/// the [`ClientHello`] message.
51pub struct StartHandshake<S: Splittable>(futures_rustls::StartHandshake<Pin<Box<AsyncStream<S>>>>);
52
53impl<S: Splittable + 'static> StartHandshake<S>
54where
55    S::ReadHalf: AsyncRead + Unpin,
56    S::WriteHalf: AsyncWrite + Unpin,
57{
58    /// Get the [`ClientHello`] message from the initial handshake.
59    pub fn client_hello(&self) -> ClientHello<'_> {
60        self.0.client_hello()
61    }
62
63    /// Complete the TLS handshake and return a [`TlsStream`] if successful.
64    pub fn into_stream(
65        self,
66        config: Arc<ServerConfig>,
67    ) -> impl Future<Output = io::Result<TlsStream<S>>> {
68        self.into_stream_with(config, |_| ())
69    }
70
71    /// Complete the TLS handshake and return a [`TlsStream`] if successful.
72    pub fn into_stream_with<F>(
73        self,
74        config: Arc<ServerConfig>,
75        f: F,
76    ) -> impl Future<Output = io::Result<TlsStream<S>>>
77    where
78        F: FnOnce(&mut ServerConnection),
79    {
80        self.0
81            .into_stream_with(config, f)
82            .map(|res| res.map(TlsStream::from))
83    }
84}